...

Source file src/k8s.io/client-go/tools/remotecommand/websocket_test.go

Documentation: k8s.io/client-go/tools/remotecommand

     1  /*
     2  Copyright 2023 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package remotecommand
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"crypto/rand"
    23  	"encoding/json"
    24  	"fmt"
    25  	"io"
    26  	"math"
    27  	mrand "math/rand"
    28  	"net/http"
    29  	"net/http/httptest"
    30  	"net/url"
    31  	"reflect"
    32  	"strings"
    33  	"sync"
    34  	"testing"
    35  	"time"
    36  
    37  	gwebsocket "github.com/gorilla/websocket"
    38  
    39  	v1 "k8s.io/api/core/v1"
    40  	apierrors "k8s.io/apimachinery/pkg/api/errors"
    41  	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    42  	"k8s.io/apimachinery/pkg/util/httpstream/wsstream"
    43  	"k8s.io/apimachinery/pkg/util/remotecommand"
    44  	"k8s.io/apimachinery/pkg/util/wait"
    45  	"k8s.io/client-go/rest"
    46  	clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
    47  )
    48  
    49  // TestWebSocketClient_LoopbackStdinToStdout returns random data sent on the STDIN channel
    50  // back down the STDOUT channel. A subsequent comparison checks if the data
    51  // sent on the STDIN channel is the same as the data returned on the STDOUT
    52  // channel. This test can be run many times by the "stress" tool to check
    53  // if there is any data which would cause problems with the WebSocket streams.
    54  func TestWebSocketClient_LoopbackStdinToStdout(t *testing.T) {
    55  	// Create fake WebSocket server. Copy received STDIN data back onto STDOUT stream.
    56  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
    57  		conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
    58  		if err != nil {
    59  			t.Fatalf("error on webSocketServerStreams: %v", err)
    60  		}
    61  		defer conns.conn.Close()
    62  		// Loopback the STDIN stream onto the STDOUT stream.
    63  		_, err = io.Copy(conns.stdoutStream, conns.stdinStream)
    64  		if err != nil {
    65  			t.Fatalf("error copying STDIN to STDOUT: %v", err)
    66  		}
    67  	}))
    68  	defer websocketServer.Close()
    69  
    70  	// Now create the WebSocket client (executor), and point it to the "websocketServer".
    71  	// Must add STDIN and STDOUT query params for the WebSocket client request.
    72  	websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true"
    73  	websocketLocation, err := url.Parse(websocketServer.URL)
    74  	if err != nil {
    75  		t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
    76  	}
    77  	exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
    78  	if err != nil {
    79  		t.Errorf("unexpected error creating websocket executor: %v", err)
    80  	}
    81  	// Generate random data, and set it up to stream on STDIN. The data will be
    82  	// returned on the STDOUT buffer.
    83  	randomSize := 1024 * 1024
    84  	randomData := make([]byte, randomSize)
    85  	if _, err := rand.Read(randomData); err != nil {
    86  		t.Errorf("unexpected error reading random data: %v", err)
    87  	}
    88  	var stdout bytes.Buffer
    89  	options := &StreamOptions{
    90  		Stdin:  bytes.NewReader(randomData),
    91  		Stdout: &stdout,
    92  	}
    93  	errorChan := make(chan error)
    94  	go func() {
    95  		// Start the streaming on the WebSocket "exec" client.
    96  		errorChan <- exec.StreamWithContext(context.Background(), *options)
    97  	}()
    98  
    99  	select {
   100  	case <-time.After(wait.ForeverTestTimeout):
   101  		t.Fatalf("expect stream to be closed after connection is closed.")
   102  	case err := <-errorChan:
   103  		if err != nil {
   104  			t.Errorf("unexpected error")
   105  		}
   106  		// Validate remote command v5 protocol was negotiated.
   107  		streamExec := exec.(*wsStreamExecutor)
   108  		if remotecommand.StreamProtocolV5Name != streamExec.negotiated {
   109  			t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated)
   110  		}
   111  	}
   112  	data, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
   113  	if err != nil {
   114  		t.Fatalf("error reading the stream: %v", err)
   115  	}
   116  	// Check the random data sent on STDIN was the same returned on STDOUT.
   117  	if !bytes.Equal(randomData, data) {
   118  		t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData))
   119  	}
   120  }
   121  
   122  // TestWebSocketClient_DifferentBufferSizes runs the previous loopback (STDIN -> STDOUT) test with different
   123  // buffer sizes for reading from the opposite end of the websocket connection (in the websocket server).
   124  func TestWebSocketClient_DifferentBufferSizes(t *testing.T) {
   125  	// 1k, 4k, 64k, and 128k buffer sizes for reading STDIN at websocket server endpoint.
   126  	// The standard buffer size for io.Copy is 32k.
   127  	bufferSizes := []int{1 * 1024, 4 * 1024, 64 * 1024, 128 * 1024}
   128  	for _, bufferSize := range bufferSizes {
   129  		// Create fake WebSocket server. Copy received STDIN data back onto STDOUT stream.
   130  		websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   131  			conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
   132  			if err != nil {
   133  				t.Fatalf("error on webSocketServerStreams: %v", err)
   134  			}
   135  			defer conns.conn.Close()
   136  			// Loopback the STDIN stream onto the STDOUT stream, using buffer with size.
   137  			buffer := make([]byte, bufferSize)
   138  			_, err = io.CopyBuffer(conns.stdoutStream, conns.stdinStream, buffer)
   139  			if err != nil {
   140  				t.Fatalf("error copying STDIN to STDOUT: %v", err)
   141  			}
   142  		}))
   143  		defer websocketServer.Close()
   144  
   145  		// Now create the WebSocket client (executor), and point it to the "websocketServer".
   146  		// Must add STDIN and STDOUT query params for the WebSocket client request.
   147  		websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true"
   148  		websocketLocation, err := url.Parse(websocketServer.URL)
   149  		if err != nil {
   150  			t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
   151  		}
   152  		exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
   153  		if err != nil {
   154  			t.Errorf("unexpected error creating websocket executor: %v", err)
   155  		}
   156  		// Generate random data, and set it up to stream on STDIN. The data will be
   157  		// returned on the STDOUT buffer.
   158  		randomSize := 1024 * 1024
   159  		randomData := make([]byte, randomSize)
   160  		if _, err := rand.Read(randomData); err != nil {
   161  			t.Errorf("unexpected error reading random data: %v", err)
   162  		}
   163  		var stdout bytes.Buffer
   164  		options := &StreamOptions{
   165  			Stdin:  bytes.NewReader(randomData),
   166  			Stdout: &stdout,
   167  		}
   168  		errorChan := make(chan error)
   169  		go func() {
   170  			// Start the streaming on the WebSocket "exec" client.
   171  			errorChan <- exec.StreamWithContext(context.Background(), *options)
   172  		}()
   173  
   174  		select {
   175  		case <-time.After(wait.ForeverTestTimeout):
   176  			t.Fatalf("expect stream to be closed after connection is closed.")
   177  		case err := <-errorChan:
   178  			if err != nil {
   179  				t.Errorf("unexpected error")
   180  			}
   181  			// Validate remote command v5 protocol was negotiated.
   182  			streamExec := exec.(*wsStreamExecutor)
   183  			if remotecommand.StreamProtocolV5Name != streamExec.negotiated {
   184  				t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated)
   185  			}
   186  		}
   187  		data, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
   188  		if err != nil {
   189  			t.Errorf("error reading the stream: %v", err)
   190  			return
   191  		}
   192  		// Check all the random data sent on STDIN was the same returned on STDOUT.
   193  		if !bytes.Equal(randomData, data) {
   194  			t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData))
   195  		}
   196  	}
   197  }
   198  
   199  // TestWebSocketClient_LoopbackStdinAsPipe uses a pipe to send random data on the STDIN
   200  // channel, then closes the pipe. The fake server simply returns all STDIN data back
   201  // onto the STDOUT channel, and the received data on STDOUT is validated against the
   202  // random data initially sent.
   203  func TestWebSocketClient_LoopbackStdinAsPipe(t *testing.T) {
   204  	// Create fake WebSocket server. Copy received STDIN data back onto STDOUT stream.
   205  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   206  		conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
   207  		if err != nil {
   208  			t.Fatalf("error on webSocketServerStreams: %v", err)
   209  		}
   210  		defer conns.conn.Close()
   211  		// Loopback the STDIN stream onto the STDOUT stream.
   212  		_, err = io.Copy(conns.stdoutStream, conns.stdinStream)
   213  		if err != nil {
   214  			t.Fatalf("error copying STDIN to STDOUT: %v", err)
   215  		}
   216  	}))
   217  	defer websocketServer.Close()
   218  
   219  	// Now create the WebSocket client (executor), and point it to the "websocketServer".
   220  	// Must add STDIN and STDOUT query params for the WebSocket client request.
   221  	websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true"
   222  	websocketLocation, err := url.Parse(websocketServer.URL)
   223  	if err != nil {
   224  		t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
   225  	}
   226  	exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
   227  	if err != nil {
   228  		t.Errorf("unexpected error creating websocket executor: %v", err)
   229  	}
   230  	// Generate random data, and it will be written on the STDIN pipe. The same
   231  	// data will be returned on the STDOUT channel.
   232  	randomSize := 1024 * 1024
   233  	randomData := make([]byte, randomSize)
   234  	if _, err := rand.Read(randomData); err != nil {
   235  		t.Errorf("unexpected error reading random data: %v", err)
   236  	}
   237  	reader, writer := io.Pipe()
   238  	var stdout bytes.Buffer
   239  	options := &StreamOptions{
   240  		Stdin:  reader,
   241  		Stdout: &stdout,
   242  	}
   243  	errorChan := make(chan error)
   244  	go func() {
   245  		// Start the streaming on the WebSocket "exec" client.
   246  		errorChan <- exec.StreamWithContext(context.Background(), *options)
   247  	}()
   248  	// Write the random data onto the pipe connected to STDIN, then close the pipe.
   249  	_, err = writer.Write(randomData)
   250  	if err != nil {
   251  		t.Fatalf("unable to write random data to STDIN pipe: %v", err)
   252  	}
   253  	writer.Close()
   254  
   255  	select {
   256  	case <-time.After(wait.ForeverTestTimeout):
   257  		t.Fatalf("expect stream to be closed after connection is closed.")
   258  	case err := <-errorChan:
   259  		if err != nil {
   260  			t.Errorf("unexpected error")
   261  		}
   262  		// Validate remote command v5 protocol was negotiated.
   263  		streamExec := exec.(*wsStreamExecutor)
   264  		if remotecommand.StreamProtocolV5Name != streamExec.negotiated {
   265  			t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated)
   266  		}
   267  	}
   268  	data, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
   269  	if err != nil {
   270  		t.Errorf("error reading the stream: %v", err)
   271  		return
   272  	}
   273  	// Check the random data sent on STDIN was the same returned on STDOUT.
   274  	if !bytes.Equal(randomData, data) {
   275  		t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData))
   276  	}
   277  }
   278  
   279  // TestWebSocketClient_LoopbackStdinToStderr returns random data sent on the STDIN channel
   280  // back down the STDERR channel. A subsequent comparison checks if the data
   281  // sent on the STDIN channel is the same as the data returned on the STDERR
   282  // channel. This test can be run many times by the "stress" tool to check
   283  // if there is any data which would cause problems with the WebSocket streams.
   284  func TestWebSocketClient_LoopbackStdinToStderr(t *testing.T) {
   285  	// Create fake WebSocket server. Copy received STDIN data back onto STDERR stream.
   286  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   287  		conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
   288  		if err != nil {
   289  			t.Fatalf("error on webSocketServerStreams: %v", err)
   290  		}
   291  		defer conns.conn.Close()
   292  		// Loopback the STDIN stream onto the STDERR stream.
   293  		_, err = io.Copy(conns.stderrStream, conns.stdinStream)
   294  		if err != nil {
   295  			t.Fatalf("error copying STDIN to STDERR: %v", err)
   296  		}
   297  	}))
   298  	defer websocketServer.Close()
   299  
   300  	// Now create the WebSocket client (executor), and point it to the "websocketServer".
   301  	// Must add STDIN and STDERR query params for the WebSocket client request.
   302  	websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stderr=true"
   303  	websocketLocation, err := url.Parse(websocketServer.URL)
   304  	if err != nil {
   305  		t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
   306  	}
   307  	exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
   308  	if err != nil {
   309  		t.Errorf("unexpected error creating websocket executor: %v", err)
   310  	}
   311  	// Generate random data, and set it up to stream on STDIN. The data will be
   312  	// returned on the STDERR buffer.
   313  	randomSize := 1024 * 1024
   314  	randomData := make([]byte, randomSize)
   315  	if _, err := rand.Read(randomData); err != nil {
   316  		t.Errorf("unexpected error reading random data: %v", err)
   317  	}
   318  	var stderr bytes.Buffer
   319  	options := &StreamOptions{
   320  		Stdin:  bytes.NewReader(randomData),
   321  		Stderr: &stderr,
   322  	}
   323  	errorChan := make(chan error)
   324  	go func() {
   325  		// Start the streaming on the WebSocket "exec" client.
   326  		errorChan <- exec.StreamWithContext(context.Background(), *options)
   327  	}()
   328  
   329  	select {
   330  	case <-time.After(wait.ForeverTestTimeout):
   331  		t.Fatalf("expect stream to be closed after connection is closed.")
   332  	case err := <-errorChan:
   333  		if err != nil {
   334  			t.Errorf("unexpected error")
   335  		}
   336  		// Validate remote command v5 protocol was negotiated.
   337  		streamExec := exec.(*wsStreamExecutor)
   338  		if remotecommand.StreamProtocolV5Name != streamExec.negotiated {
   339  			t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated)
   340  		}
   341  	}
   342  	data, err := io.ReadAll(bytes.NewReader(stderr.Bytes()))
   343  	if err != nil {
   344  		t.Errorf("error reading the stream: %v", err)
   345  		return
   346  	}
   347  	// Check the random data sent on STDIN was the same returned on STDERR.
   348  	if !bytes.Equal(randomData, data) {
   349  		t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData))
   350  	}
   351  }
   352  
   353  // TestWebSocketClient_MultipleReadChannels tests two streams (STDOUT, STDERR) reading from
   354  // the websocket connection at the same time.
   355  func TestWebSocketClient_MultipleReadChannels(t *testing.T) {
   356  	// Create fake WebSocket server, which uses a TeeReader to copy the same data
   357  	// onto the STDOUT stream onto the STDERR stream as well.
   358  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   359  		conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
   360  		if err != nil {
   361  			t.Fatalf("error on webSocketServerStreams: %v", err)
   362  		}
   363  		defer conns.conn.Close()
   364  		// TeeReader copies data read on STDIN onto STDERR.
   365  		stdinReader := io.TeeReader(conns.stdinStream, conns.stderrStream)
   366  		// Also copy STDIN to STDOUT.
   367  		_, err = io.Copy(conns.stdoutStream, stdinReader)
   368  		if err != nil {
   369  			t.Errorf("error copying STDIN to STDOUT: %v", err)
   370  		}
   371  	}))
   372  	defer websocketServer.Close()
   373  	// Now create the WebSocket client (executor), and point it to the "websocketServer".
   374  	// Must add stdin, stdout, and stderr query param for the WebSocket client request.
   375  	websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true" + "&" + "stderr=true"
   376  	websocketLocation, err := url.Parse(websocketServer.URL)
   377  	if err != nil {
   378  		t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
   379  	}
   380  	exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
   381  	if err != nil {
   382  		t.Errorf("unexpected error creating websocket executor: %v", err)
   383  	}
   384  	// Generate 1MB of random data, and set it up to stream on STDIN. The data will be
   385  	// returned on the STDOUT and STDERR buffers.
   386  	randomSize := 1024 * 1024
   387  	randomData := make([]byte, randomSize)
   388  	if _, err := rand.Read(randomData); err != nil {
   389  		t.Errorf("unexpected error reading random data: %v", err)
   390  	}
   391  	var stdout, stderr bytes.Buffer
   392  	options := &StreamOptions{
   393  		Stdin:  bytes.NewReader(randomData),
   394  		Stdout: &stdout,
   395  		Stderr: &stderr,
   396  	}
   397  	errorChan := make(chan error)
   398  	go func() {
   399  		errorChan <- exec.StreamWithContext(context.Background(), *options)
   400  	}()
   401  
   402  	select {
   403  	case <-time.After(wait.ForeverTestTimeout):
   404  		t.Fatalf("expect stream to be closed after connection is closed.")
   405  	case err := <-errorChan:
   406  		if err != nil {
   407  			t.Errorf("unexpected error: %v", err)
   408  		}
   409  		// Validate remote command v5 protocol was negotiated.
   410  		streamExec := exec.(*wsStreamExecutor)
   411  		if remotecommand.StreamProtocolV5Name != streamExec.negotiated {
   412  			t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated)
   413  		}
   414  	}
   415  	// Validate the data read from the STDOUT stream is the same as sent on the STDIN stream.
   416  	stdoutBytes, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
   417  	if err != nil {
   418  		t.Fatalf("error reading the stream: %v", err)
   419  	}
   420  	if !bytes.Equal(stdoutBytes, randomData) {
   421  		t.Errorf("unexpected data received (%d) sent (%d)", len(stdoutBytes), len(randomData))
   422  	}
   423  	// Validate the data read from the STDERR stream is the same as sent on the STDIN stream.
   424  	stderrBytes, err := io.ReadAll(bytes.NewReader(stderr.Bytes()))
   425  	if err != nil {
   426  		t.Fatalf("error reading the stream: %v", err)
   427  	}
   428  	if !bytes.Equal(stderrBytes, randomData) {
   429  		t.Errorf("unexpected data received (%d) sent (%d)", len(stderrBytes), len(randomData))
   430  	}
   431  }
   432  
   433  // Returns a random exit code in the range(1-127).
   434  func randomExitCode() int {
   435  	errorCode := mrand.Intn(128)
   436  	if errorCode == 0 {
   437  		errorCode = 1
   438  	}
   439  	return errorCode
   440  }
   441  
   442  // TestWebSocketClient_ErrorStream tests the websocket error stream by hard-coding a
   443  // structured non-zero exit code error from the websocket server to the websocket client.
   444  func TestWebSocketClient_ErrorStream(t *testing.T) {
   445  	expectedExitCode := randomExitCode()
   446  	// Create fake WebSocket server. Returns structured exit code error on error stream.
   447  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   448  		conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
   449  		if err != nil {
   450  			t.Fatalf("error on webSocketServerStreams: %v", err)
   451  		}
   452  		defer conns.conn.Close()
   453  		_, err = io.Copy(conns.stderrStream, conns.stdinStream)
   454  		if err != nil {
   455  			t.Fatalf("error copying STDIN to STDERR: %v", err)
   456  		}
   457  		// Force an non-zero exit code error returned on the error stream.
   458  		err = conns.writeStatus(&apierrors.StatusError{ErrStatus: metav1.Status{
   459  			Status: metav1.StatusFailure,
   460  			Reason: remotecommand.NonZeroExitCodeReason,
   461  			Details: &metav1.StatusDetails{
   462  				Causes: []metav1.StatusCause{
   463  					{
   464  						Type:    remotecommand.ExitCodeCauseType,
   465  						Message: fmt.Sprintf("%d", expectedExitCode),
   466  					},
   467  				},
   468  			},
   469  		}})
   470  		if err != nil {
   471  			t.Fatalf("error writing status: %v", err)
   472  		}
   473  	}))
   474  	defer websocketServer.Close()
   475  
   476  	// Now create the WebSocket client (executor), and point it to the "websocketServer".
   477  	websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stderr=true"
   478  	websocketLocation, err := url.Parse(websocketServer.URL)
   479  	if err != nil {
   480  		t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
   481  	}
   482  	exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
   483  	if err != nil {
   484  		t.Errorf("unexpected error creating websocket executor: %v", err)
   485  	}
   486  	randomData := make([]byte, 256)
   487  	if _, err := rand.Read(randomData); err != nil {
   488  		t.Errorf("unexpected error reading random data: %v", err)
   489  	}
   490  	var stderr bytes.Buffer
   491  	options := &StreamOptions{
   492  		Stdin:  bytes.NewReader(randomData),
   493  		Stderr: &stderr,
   494  	}
   495  	errorChan := make(chan error)
   496  	go func() {
   497  		// Start the streaming on the WebSocket "exec" client.
   498  		errorChan <- exec.StreamWithContext(context.Background(), *options)
   499  	}()
   500  
   501  	select {
   502  	case <-time.After(wait.ForeverTestTimeout):
   503  		t.Fatalf("expect stream to be closed after connection is closed.")
   504  	case err := <-errorChan:
   505  		// Validate remote command v5 protocol was negotiated.
   506  		streamExec := exec.(*wsStreamExecutor)
   507  		if remotecommand.StreamProtocolV5Name != streamExec.negotiated {
   508  			t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated)
   509  		}
   510  		// Expect exit code error on error stream.
   511  		if err == nil {
   512  			t.Errorf("expected error, but received none")
   513  		}
   514  		expectedError := fmt.Sprintf("command terminated with exit code %d", expectedExitCode)
   515  		// Compare expected error with exit code to actual error.
   516  		if expectedError != err.Error() {
   517  			t.Errorf("expected error (%s), got (%s)", expectedError, err)
   518  		}
   519  	}
   520  }
   521  
   522  // fakeTerminalSizeQueue implements TerminalSizeQueue, returning a random set of
   523  // "maxSizes" number of TerminalSizes, storing the TerminalSizes in "sizes" slice.
   524  type fakeTerminalSizeQueue struct {
   525  	maxSizes      int
   526  	terminalSizes []TerminalSize
   527  }
   528  
   529  // newTerminalSizeQueue returns a pointer to a fakeTerminalSizeQueue passing
   530  // "max" number of random TerminalSizes created.
   531  func newTerminalSizeQueue(max int) *fakeTerminalSizeQueue {
   532  	return &fakeTerminalSizeQueue{
   533  		maxSizes:      max,
   534  		terminalSizes: make([]TerminalSize, 0, max),
   535  	}
   536  }
   537  
   538  // Next returns a pointer to the next random TerminalSize, or nil if we have
   539  // already returned "maxSizes" TerminalSizes already. Stores the randomly
   540  // created TerminalSize in "terminalSizes" field for later validation.
   541  func (f *fakeTerminalSizeQueue) Next() *TerminalSize {
   542  	if len(f.terminalSizes) >= f.maxSizes {
   543  		return nil
   544  	}
   545  	size := randomTerminalSize()
   546  	f.terminalSizes = append(f.terminalSizes, size)
   547  	return &size
   548  }
   549  
   550  // randomTerminalSize returns a TerminalSize with random values in the
   551  // range (0-65535) for the fields Width and Height.
   552  func randomTerminalSize() TerminalSize {
   553  	randWidth := uint16(mrand.Intn(int(math.Pow(2, 16))))
   554  	randHeight := uint16(mrand.Intn(int(math.Pow(2, 16))))
   555  	return TerminalSize{
   556  		Width:  randWidth,
   557  		Height: randHeight,
   558  	}
   559  }
   560  
   561  // randReader implements the ReadCloser interface, and it continuously
   562  // returns random data until it is closed. Stores number of random
   563  // bytes generated and returned.
   564  type randReader struct {
   565  	randBytes []byte
   566  	closed    bool
   567  	lock      sync.Mutex
   568  }
   569  
   570  // Read implements the Reader interface filling the passed buffer with
   571  // random data, returning the number of bytes filled and an error
   572  // if one occurs. Return 0 and EOF if the randReader has been closed.
   573  func (r *randReader) Read(b []byte) (int, error) {
   574  	r.lock.Lock()
   575  	defer r.lock.Unlock()
   576  	if r.closed {
   577  		return 0, io.EOF
   578  	}
   579  	n, err := rand.Read(b)
   580  	c := bytes.Clone(b)
   581  	r.randBytes = append(r.randBytes, c...)
   582  	return n, err
   583  }
   584  
   585  // Close implements the Closer interface, setting the close field true.
   586  // Further calls to Read() after Close() will return 0, EOF. Returns
   587  // nil error.
   588  func (r *randReader) Close() (err error) {
   589  	r.lock.Lock()
   590  	defer r.lock.Unlock()
   591  	r.closed = true
   592  	return nil
   593  }
   594  
   595  // TestWebSocketClient_MultipleWriteChannels tests two streams (STDIN, TTY resize) writing to the
   596  // websocket connection at the same time to exercise the connection write lock.
   597  func TestWebSocketClient_MultipleWriteChannels(t *testing.T) {
   598  	// Create the fake terminal size queue and the actualTerminalSizes which
   599  	// will be received at the opposite websocket endpoint.
   600  	numSizeQueue := 10000
   601  	sizeQueue := newTerminalSizeQueue(numSizeQueue)
   602  	actualTerminalSizes := make([]TerminalSize, 0, numSizeQueue)
   603  	// Create ReadCloser sending random data on STDIN stream over websocket connection.
   604  	stdinReader := randReader{randBytes: []byte{}, closed: false}
   605  	// Create fake WebSocket server, which will receive concurrently the STDIN stream as
   606  	// well as the resize stream (TerminalSizes). Store the TerminalSize data from the resize
   607  	// stream for subsequent validation.
   608  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   609  		var wg sync.WaitGroup
   610  		conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
   611  		if err != nil {
   612  			t.Fatalf("error on webSocketServerStreams: %v", err)
   613  		}
   614  		defer conns.conn.Close()
   615  		// Create goroutine to loopback the STDIN stream onto the STDOUT stream.
   616  		wg.Add(1)
   617  		go func() {
   618  			_, err := io.Copy(conns.stdoutStream, conns.stdinStream)
   619  			if err != nil {
   620  				t.Errorf("error copying STDIN to STDOUT: %v", err)
   621  			}
   622  			wg.Done()
   623  		}()
   624  		// Read the terminal resize requests, storing them in actualTerminalSizes
   625  		for i := 0; i < numSizeQueue; i++ {
   626  			actualTerminalSize := <-conns.resizeChan
   627  			actualTerminalSizes = append(actualTerminalSizes, actualTerminalSize)
   628  		}
   629  		stdinReader.Close() // Stops the random STDIN stream generation
   630  		wg.Wait()           // Wait for all bytes copied from STDIN to STDOUT
   631  	}))
   632  	defer websocketServer.Close()
   633  	// Now create the WebSocket client (executor), and point it to the "websocketServer".
   634  	// Must add stdin, stdout, and TTY query param for the WebSocket client request.
   635  	websocketServer.URL = websocketServer.URL + "?" + "tty=true" + "&" + "stdin=true" + "&" + "stdout=true"
   636  	websocketLocation, err := url.Parse(websocketServer.URL)
   637  	if err != nil {
   638  		t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
   639  	}
   640  	exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
   641  	if err != nil {
   642  		t.Errorf("unexpected error creating websocket executor: %v", err)
   643  	}
   644  	var stdout bytes.Buffer
   645  	options := &StreamOptions{
   646  		Stdin:             &stdinReader,
   647  		Stdout:            &stdout,
   648  		Tty:               true,
   649  		TerminalSizeQueue: sizeQueue,
   650  	}
   651  	errorChan := make(chan error)
   652  	go func() {
   653  		errorChan <- exec.StreamWithContext(context.Background(), *options)
   654  	}()
   655  
   656  	select {
   657  	case <-time.After(wait.ForeverTestTimeout):
   658  		t.Fatalf("expect stream to be closed after connection is closed.")
   659  	case err := <-errorChan:
   660  		if err != nil {
   661  			t.Errorf("unexpected error: %v", err)
   662  		}
   663  		// Validate remote command v5 protocol was negotiated.
   664  		streamExec := exec.(*wsStreamExecutor)
   665  		if remotecommand.StreamProtocolV5Name != streamExec.negotiated {
   666  			t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated)
   667  		}
   668  	}
   669  	// Check the random data sent on STDIN was the same returned on STDOUT *and*
   670  	// that a minimum amount of random data was sent and received, ensuring concurrency.
   671  	stdoutBytes, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
   672  	if err != nil {
   673  		t.Fatalf("error reading the stream: %v", err)
   674  	}
   675  	if len(stdoutBytes) == 0 {
   676  		t.Errorf("No STDOUT bytes processed before resize stream finished: %d", len(stdoutBytes))
   677  	}
   678  	if !bytes.Equal(stdoutBytes, stdinReader.randBytes) {
   679  		t.Errorf("unexpected data received (%d) sent (%d)", len(stdoutBytes), len(stdinReader.randBytes))
   680  	}
   681  	// Validate the random TerminalSizes sent on the resize stream are the same
   682  	// as the actual TerminalSizes received at the websocket server.
   683  	if len(actualTerminalSizes) != numSizeQueue {
   684  		t.Errorf("expected received terminal size window (%d), got (%d)",
   685  			numSizeQueue, len(actualTerminalSizes))
   686  	}
   687  	for i, actual := range actualTerminalSizes {
   688  		expected := sizeQueue.terminalSizes[i]
   689  		if !reflect.DeepEqual(expected, actual) {
   690  			t.Errorf("expected terminal resize window %v, got %v", expected, actual)
   691  		}
   692  	}
   693  }
   694  
   695  // TestWebSocketClient_ProtocolVersions validates that remote command subprotocol versions V2-V4
   696  // (V5 is already tested elsewhere) can be negotiated.
   697  func TestWebSocketClient_ProtocolVersions(t *testing.T) {
   698  	// Create a raw websocket server that accepts V2-V4 versions of
   699  	// the remote command subprotocol.
   700  	var upgrader = gwebsocket.Upgrader{
   701  		CheckOrigin: func(r *http.Request) bool {
   702  			return true // Accepting all requests
   703  		},
   704  		Subprotocols: []string{
   705  			remotecommand.StreamProtocolV4Name,
   706  			remotecommand.StreamProtocolV3Name,
   707  			remotecommand.StreamProtocolV2Name,
   708  		},
   709  	}
   710  	// Upgrade a raw websocket server connection.
   711  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   712  		conn, err := upgrader.Upgrade(w, req, nil)
   713  		if err != nil {
   714  			t.Fatalf("unable to upgrade to create websocket connection: %v", err)
   715  		}
   716  		defer conn.Close()
   717  	}))
   718  	defer websocketServer.Close()
   719  
   720  	// Set up the websocket client with the STDOUT stream.
   721  	websocketServer.URL = websocketServer.URL + "?" + "stdout=true"
   722  	websocketLocation, err := url.Parse(websocketServer.URL)
   723  	if err != nil {
   724  		t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
   725  	}
   726  	exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
   727  	if err != nil {
   728  		t.Errorf("unexpected error creating websocket executor: %v", err)
   729  	}
   730  	// Iterate through previous remote command protocol versions, validating the
   731  	// requested protocol version is the one that is negotiated.
   732  	versions := []string{
   733  		remotecommand.StreamProtocolV4Name,
   734  		remotecommand.StreamProtocolV3Name,
   735  		remotecommand.StreamProtocolV2Name,
   736  	}
   737  	for _, requestedVersion := range versions {
   738  		streamExec := exec.(*wsStreamExecutor)
   739  		streamExec.protocols = []string{requestedVersion}
   740  		var stdout bytes.Buffer
   741  		options := &StreamOptions{
   742  			Stdout: &stdout,
   743  		}
   744  		errorChan := make(chan error)
   745  		go func() {
   746  			// Start the streaming on the WebSocket "exec" client.
   747  			errorChan <- exec.StreamWithContext(context.Background(), *options)
   748  		}()
   749  
   750  		select {
   751  		case <-time.After(wait.ForeverTestTimeout):
   752  			t.Fatalf("expect stream to be closed after connection is closed.")
   753  		case <-errorChan:
   754  			// Validate remote command protocol requestedVersion was negotiated.
   755  			streamExec := exec.(*wsStreamExecutor)
   756  			if requestedVersion != streamExec.negotiated {
   757  				t.Fatalf("expected protocol version (%s), got (%s)", requestedVersion, streamExec.negotiated)
   758  			}
   759  		}
   760  	}
   761  }
   762  
   763  // TestWebSocketClient_BadHandshake tests that a "bad handshake" error occurs when
   764  // the WebSocketExecutor attempts to upgrade the connection to a subprotocol version
   765  // (V4) that is not supported by the websocket server (only supports V5).
   766  func TestWebSocketClient_BadHandshake(t *testing.T) {
   767  	// Create fake WebSocket server (supports V5 subprotocol).
   768  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   769  		// Bad handshake means websocket server will not completely initialize.
   770  		_, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
   771  		if err == nil {
   772  			t.Fatalf("expected error, but received none.")
   773  		}
   774  		if !strings.Contains(err.Error(), "websocket server finished before becoming ready") {
   775  			t.Errorf("expected websocket server error, but got: %v", err)
   776  		}
   777  	}))
   778  	defer websocketServer.Close()
   779  
   780  	websocketServer.URL = websocketServer.URL + "?" + "stdout=true"
   781  	websocketLocation, err := url.Parse(websocketServer.URL)
   782  	if err != nil {
   783  		t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
   784  	}
   785  	exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
   786  	if err != nil {
   787  		t.Errorf("unexpected error creating websocket executor: %v", err)
   788  	}
   789  	streamExec := exec.(*wsStreamExecutor)
   790  	// Set the attempted subprotocol version to V4; websocket server only accepts V5.
   791  	streamExec.protocols = []string{remotecommand.StreamProtocolV4Name}
   792  
   793  	var stdout bytes.Buffer
   794  	options := &StreamOptions{
   795  		Stdout: &stdout,
   796  	}
   797  	errorChan := make(chan error)
   798  	go func() {
   799  		// Start the streaming on the WebSocket "exec" client.
   800  		errorChan <- streamExec.StreamWithContext(context.Background(), *options)
   801  	}()
   802  
   803  	select {
   804  	case <-time.After(wait.ForeverTestTimeout):
   805  		t.Fatalf("expect stream to be closed after connection is closed.")
   806  	case err := <-errorChan:
   807  		// Expecting unable to upgrade connection -- "bad handshake" error.
   808  		if err == nil {
   809  			t.Errorf("expected error but received none")
   810  		}
   811  		if !strings.Contains(err.Error(), "bad handshake") {
   812  			t.Errorf("expected bad handshake error, got (%s)", err)
   813  		}
   814  	}
   815  }
   816  
   817  // TestWebSocketClient_HeartbeatTimeout tests the heartbeat by forcing a
   818  // timeout by setting the ping period greater than the deadline.
   819  func TestWebSocketClient_HeartbeatTimeout(t *testing.T) {
   820  	blockRequestCtx, unblockRequest := context.WithCancel(context.Background())
   821  	defer unblockRequest()
   822  	// Create fake WebSocket server which blocks.
   823  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   824  		conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
   825  		if err != nil {
   826  			t.Fatalf("error on webSocketServerStreams: %v", err)
   827  		}
   828  		defer conns.conn.Close()
   829  		<-blockRequestCtx.Done()
   830  	}))
   831  	defer websocketServer.Close()
   832  	// Create websocket client connecting to fake server.
   833  	websocketServer.URL = websocketServer.URL + "?" + "stdin=true"
   834  	websocketLocation, err := url.Parse(websocketServer.URL)
   835  	if err != nil {
   836  		t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
   837  	}
   838  	exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
   839  	if err != nil {
   840  		t.Errorf("unexpected error creating websocket executor: %v", err)
   841  	}
   842  	streamExec := exec.(*wsStreamExecutor)
   843  	// Ping period is greater than the ping deadline, forcing the timeout to fire.
   844  	pingPeriod := wait.ForeverTestTimeout // this lets the heartbeat deadline expire without renewing it
   845  	pingDeadline := time.Second           // this gives setup 1 second to establish streams
   846  	streamExec.heartbeatPeriod = pingPeriod
   847  	streamExec.heartbeatDeadline = pingDeadline
   848  	// Send some random data to the websocket server through STDIN.
   849  	randomData := make([]byte, 128)
   850  	if _, err := rand.Read(randomData); err != nil {
   851  		t.Errorf("unexpected error reading random data: %v", err)
   852  	}
   853  	options := &StreamOptions{
   854  		Stdin: bytes.NewReader(randomData),
   855  	}
   856  	errorChan := make(chan error)
   857  	go func() {
   858  		// Start the streaming on the WebSocket "exec" client.
   859  		errorChan <- streamExec.StreamWithContext(context.Background(), *options)
   860  	}()
   861  
   862  	select {
   863  	case <-time.After(wait.ForeverTestTimeout):
   864  		t.Fatalf("expected heartbeat timeout, got none.")
   865  	case err := <-errorChan:
   866  		// Expecting heartbeat timeout error.
   867  		if err == nil {
   868  			t.Fatalf("expected error but received none")
   869  		}
   870  		if !strings.Contains(err.Error(), "i/o timeout") {
   871  			t.Errorf("expected heartbeat timeout error, got (%s)", err)
   872  		}
   873  		// Validate remote command v5 protocol was negotiated.
   874  		streamExec := exec.(*wsStreamExecutor)
   875  		if remotecommand.StreamProtocolV5Name != streamExec.negotiated {
   876  			t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated)
   877  		}
   878  	}
   879  }
   880  
   881  // TestWebSocketClient_TextMessageTypeError tests when the wrong message type is returned
   882  // from the other websocket endpoint. Remote command protocols use "BinaryMessage", but
   883  // this test hard-codes returning a "TextMessage".
   884  func TestWebSocketClient_TextMessageTypeError(t *testing.T) {
   885  	var upgrader = gwebsocket.Upgrader{
   886  		CheckOrigin: func(r *http.Request) bool {
   887  			return true // Accepting all requests
   888  		},
   889  		Subprotocols: []string{remotecommand.StreamProtocolV5Name},
   890  	}
   891  	// Upgrade a raw websocket server connection. Returns wrong message type "TextMessage".
   892  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   893  		conn, err := upgrader.Upgrade(w, req, nil)
   894  		if err != nil {
   895  			t.Fatalf("unable to upgrade to create websocket connection: %v", err)
   896  		}
   897  		defer conn.Close()
   898  		msg := []byte("test message with wrong message type.")
   899  		stdOutMsg := append([]byte{remotecommand.StreamStdOut}, msg...)
   900  		// Wrong message type "TextMessage".
   901  		err = conn.WriteMessage(gwebsocket.TextMessage, stdOutMsg)
   902  		if err != nil {
   903  			t.Fatalf("error writing text message to websocket: %v", err)
   904  		}
   905  
   906  	}))
   907  	defer websocketServer.Close()
   908  
   909  	// Set up the websocket client with the STDOUT stream.
   910  	websocketServer.URL = websocketServer.URL + "?" + "stdout=true"
   911  	websocketLocation, err := url.Parse(websocketServer.URL)
   912  	if err != nil {
   913  		t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
   914  	}
   915  	exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
   916  	if err != nil {
   917  		t.Errorf("unexpected error creating websocket executor: %v", err)
   918  	}
   919  	var stdout bytes.Buffer
   920  	options := &StreamOptions{
   921  		Stdout: &stdout,
   922  	}
   923  	errorChan := make(chan error)
   924  	go func() {
   925  		// Start the streaming on the WebSocket "exec" client.
   926  		errorChan <- exec.StreamWithContext(context.Background(), *options)
   927  	}()
   928  
   929  	select {
   930  	case <-time.After(wait.ForeverTestTimeout):
   931  		t.Fatalf("expect stream to be closed after connection is closed.")
   932  	case err := <-errorChan:
   933  		// Expecting bad message type error.
   934  		if err == nil {
   935  			t.Fatalf("expected error but received none")
   936  		}
   937  		if !strings.Contains(err.Error(), "unexpected message type") {
   938  			t.Errorf("expected bad message type error, got (%s)", err)
   939  		}
   940  		// Validate remote command v5 protocol was negotiated.
   941  		streamExec := exec.(*wsStreamExecutor)
   942  		if remotecommand.StreamProtocolV5Name != streamExec.negotiated {
   943  			t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated)
   944  		}
   945  	}
   946  }
   947  
   948  // TestWebSocketClient_EmptyMessageHandled tests that the error of a completely empty message
   949  // is handled correctly. If the message is completely empty, the initial read of the stream id
   950  // should fail (followed by cleanup).
   951  func TestWebSocketClient_EmptyMessageHandled(t *testing.T) {
   952  	var upgrader = gwebsocket.Upgrader{
   953  		CheckOrigin: func(r *http.Request) bool {
   954  			return true // Accepting all requests
   955  		},
   956  		Subprotocols: []string{remotecommand.StreamProtocolV5Name},
   957  	}
   958  	// Upgrade a raw websocket server connection. Returns wrong message type "TextMessage".
   959  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   960  		conn, err := upgrader.Upgrade(w, req, nil)
   961  		if err != nil {
   962  			t.Fatalf("unable to upgrade to create websocket connection: %v", err)
   963  		}
   964  		defer conn.Close()
   965  		// Send completely empty message, including missing initial stream id.
   966  		conn.WriteMessage(gwebsocket.BinaryMessage, []byte{}) //nolint:errcheck
   967  	}))
   968  	defer websocketServer.Close()
   969  
   970  	// Set up the websocket client with the STDOUT stream.
   971  	websocketServer.URL = websocketServer.URL + "?" + "stdout=true"
   972  	websocketLocation, err := url.Parse(websocketServer.URL)
   973  	if err != nil {
   974  		t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
   975  	}
   976  	exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
   977  	if err != nil {
   978  		t.Errorf("unexpected error creating websocket executor: %v", err)
   979  	}
   980  	var stdout bytes.Buffer
   981  	options := &StreamOptions{
   982  		Stdout: &stdout,
   983  	}
   984  	errorChan := make(chan error)
   985  	go func() {
   986  		// Start the streaming on the WebSocket "exec" client.
   987  		errorChan <- exec.StreamWithContext(context.Background(), *options)
   988  	}()
   989  
   990  	select {
   991  	case <-time.After(wait.ForeverTestTimeout):
   992  		t.Fatalf("expect stream to be closed after connection is closed.")
   993  	case err := <-errorChan:
   994  		// Expecting error reading initial stream id.
   995  		if err == nil {
   996  			t.Fatalf("expected error but received none")
   997  		}
   998  		if !strings.Contains(err.Error(), "read stream id") {
   999  			t.Errorf("expected error reading stream id, got (%s)", err)
  1000  		}
  1001  		// Validate remote command v5 protocol was negotiated.
  1002  		streamExec := exec.(*wsStreamExecutor)
  1003  		if remotecommand.StreamProtocolV5Name != streamExec.negotiated {
  1004  			t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated)
  1005  		}
  1006  	}
  1007  }
  1008  
  1009  func TestWebSocketClient_ExecutorErrors(t *testing.T) {
  1010  	// Invalid config causes transport creation error in websocket executor constructor.
  1011  	config := rest.Config{
  1012  		ExecProvider: &clientcmdapi.ExecConfig{},
  1013  		AuthProvider: &clientcmdapi.AuthProviderConfig{},
  1014  	}
  1015  	_, err := NewWebSocketExecutor(&config, "GET", "http://localhost")
  1016  	if err == nil {
  1017  		t.Errorf("expecting executor constructor error, but received none.")
  1018  	} else if !strings.Contains(err.Error(), "error creating websocket transports") {
  1019  		t.Errorf("expecting error creating transports, got (%s)", err.Error())
  1020  	}
  1021  	// Verify that a nil context will cause an error in StreamWithContext
  1022  	exec, err := NewWebSocketExecutor(&rest.Config{}, "GET", "http://localhost")
  1023  	if err != nil {
  1024  		t.Errorf("unexpected error creating websocket executor: %v", err)
  1025  	}
  1026  	errorChan := make(chan error)
  1027  	go func() {
  1028  		// Start the streaming on the WebSocket "exec" client.
  1029  		var ctx context.Context
  1030  		errorChan <- exec.StreamWithContext(ctx, StreamOptions{})
  1031  	}()
  1032  
  1033  	select {
  1034  	case <-time.After(wait.ForeverTestTimeout):
  1035  		t.Fatalf("expect stream to be closed after connection is closed.")
  1036  	case err := <-errorChan:
  1037  		// Expecting error with nil context.
  1038  		if err == nil {
  1039  			t.Fatalf("expected error but received none")
  1040  		}
  1041  		if !strings.Contains(err.Error(), "nil Context") {
  1042  			t.Errorf("expected nil context error, got (%s)", err)
  1043  		}
  1044  	}
  1045  }
  1046  
  1047  func TestWebSocketClient_HeartbeatSucceeds(t *testing.T) {
  1048  	var upgrader = gwebsocket.Upgrader{
  1049  		CheckOrigin: func(r *http.Request) bool {
  1050  			return true // Accepting all requests
  1051  		},
  1052  	}
  1053  	// Upgrade a raw websocket server connection, which automatically responds to Ping.
  1054  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
  1055  		conn, err := upgrader.Upgrade(w, req, nil)
  1056  		if err != nil {
  1057  			t.Fatalf("unable to upgrade to create websocket connection: %v", err)
  1058  		}
  1059  		defer conn.Close()
  1060  		for {
  1061  			_, _, err := conn.ReadMessage()
  1062  			if err != nil {
  1063  				break
  1064  			}
  1065  		}
  1066  	}))
  1067  	defer websocketServer.Close()
  1068  	// Create a raw websocket client, connecting to the websocket server.
  1069  	url := strings.ReplaceAll(websocketServer.URL, "http", "ws")
  1070  	client, _, err := gwebsocket.DefaultDialer.Dial(url, nil)
  1071  	if err != nil {
  1072  		t.Fatalf("dial: %v", err)
  1073  	}
  1074  	defer client.Close()
  1075  	// Create a heartbeat using the client websocket connection, and start it.
  1076  	// "period" is less than "deadline", so ping/pong heartbeat will succceed.
  1077  	var expectedMsg = "test heartbeat message"
  1078  	var period = 100 * time.Millisecond
  1079  	var deadline = 200 * time.Millisecond
  1080  	heartbeat := newHeartbeat(client, period, deadline)
  1081  	heartbeat.setMessage(expectedMsg)
  1082  	// Add a channel to the handler to retrieve the "pong" message.
  1083  	pongMsgCh := make(chan string)
  1084  	pongHandler := heartbeat.conn.PongHandler()
  1085  	heartbeat.conn.SetPongHandler(func(msg string) error {
  1086  		pongMsgCh <- msg
  1087  		return pongHandler(msg)
  1088  	})
  1089  	go heartbeat.start()
  1090  
  1091  	var wg sync.WaitGroup
  1092  	wg.Add(1)
  1093  	go func() {
  1094  		defer wg.Done()
  1095  		for {
  1096  			_, _, err := client.ReadMessage()
  1097  			if err != nil {
  1098  				t.Logf("client err reading message: %v", err)
  1099  				return
  1100  			}
  1101  		}
  1102  	}()
  1103  
  1104  	select {
  1105  	case actualMsg := <-pongMsgCh:
  1106  		close(heartbeat.closer)
  1107  		// Validate the received pong message is the same as sent in ping.
  1108  		if expectedMsg != actualMsg {
  1109  			t.Errorf("expected received pong message (%s), got (%s)", expectedMsg, actualMsg)
  1110  		}
  1111  	case <-time.After(period * 4):
  1112  		// This case should not happen.
  1113  		close(heartbeat.closer)
  1114  		t.Errorf("unexpected heartbeat timeout")
  1115  	}
  1116  	wg.Wait()
  1117  }
  1118  
  1119  func TestLateStreamCreation(t *testing.T) {
  1120  	c := newWSStreamCreator(nil)
  1121  	c.closeAllStreamReaders(nil)
  1122  	if err := c.setStream(0, nil); err == nil {
  1123  		t.Fatal("expected error adding stream after closeAllStreamReaders")
  1124  	}
  1125  }
  1126  
  1127  func TestWebSocketClient_StreamsAndExpectedErrors(t *testing.T) {
  1128  	// Validate Stream functions.
  1129  	c := newWSStreamCreator(nil)
  1130  	headers := http.Header{}
  1131  	headers.Set(v1.StreamType, v1.StreamTypeStdin)
  1132  	s, err := c.CreateStream(headers)
  1133  	if err != nil {
  1134  		t.Errorf("unexpected stream creation error: %v", err)
  1135  	}
  1136  	expectedStreamID := uint32(remotecommand.StreamStdIn)
  1137  	actualStreamID := s.Identifier()
  1138  	if expectedStreamID != actualStreamID {
  1139  		t.Errorf("expecting stream id (%d), got (%d)", expectedStreamID, actualStreamID)
  1140  	}
  1141  	actualHeaders := s.Headers()
  1142  	if !reflect.DeepEqual(headers, actualHeaders) {
  1143  		t.Errorf("expecting stream headers (%v), got (%v)", headers, actualHeaders)
  1144  	}
  1145  	// Validate stream reset does not return error.
  1146  	err = s.Reset()
  1147  	if err != nil {
  1148  		t.Errorf("unexpected error in stream reset: %v", err)
  1149  	}
  1150  	// Validate close with nil connection is an error.
  1151  	err = s.Close()
  1152  	if err == nil {
  1153  		t.Errorf("expecting stream Close error, but received none")
  1154  	}
  1155  	if !strings.Contains(err.Error(), "Close() on already closed stream") {
  1156  		t.Errorf("expected stream close error, got (%s)", err)
  1157  	}
  1158  	// Validate write with nil connection is an error.
  1159  	n, err := s.Write([]byte("not written"))
  1160  	if n != 0 {
  1161  		t.Errorf("expected zero bytes written, wrote (%d) instead", n)
  1162  	}
  1163  	if err == nil {
  1164  		t.Errorf("expecting stream Write error, but received none")
  1165  	}
  1166  	if !strings.Contains(err.Error(), "write on closed stream") {
  1167  		t.Errorf("expected stream write error, got (%s)", err)
  1168  	}
  1169  	// Validate CreateStream errors -- unknown stream
  1170  	headers = http.Header{}
  1171  	headers.Set(v1.StreamType, "UNKNOWN")
  1172  	_, err = c.CreateStream(headers)
  1173  	if err == nil {
  1174  		t.Errorf("expecting CreateStream error, but received none")
  1175  	} else if !strings.Contains(err.Error(), "unknown stream type") {
  1176  		t.Errorf("expecting unknown stream type error, got (%s)", err.Error())
  1177  	}
  1178  	// Validate CreateStream errors -- duplicate stream
  1179  	headers.Set(v1.StreamType, v1.StreamTypeError)
  1180  	c.streams[remotecommand.StreamErr] = &stream{}
  1181  	_, err = c.CreateStream(headers)
  1182  	if err == nil {
  1183  		t.Errorf("expecting CreateStream error, but received none")
  1184  	} else if !strings.Contains(err.Error(), "duplicate stream") {
  1185  		t.Errorf("expecting duplicate stream error, got (%s)", err.Error())
  1186  	}
  1187  }
  1188  
  1189  // options contains details about which streams are required for
  1190  // remote command execution.
  1191  type options struct {
  1192  	stdin  bool
  1193  	stdout bool
  1194  	stderr bool
  1195  	tty    bool
  1196  }
  1197  
  1198  // Translates query params in request into options struct.
  1199  func streamOptionsFromRequest(req *http.Request) *options {
  1200  	query := req.URL.Query()
  1201  	tty := query.Get("tty") == "true"
  1202  	stdin := query.Get("stdin") == "true"
  1203  	stdout := query.Get("stdout") == "true"
  1204  	stderr := query.Get("stderr") == "true"
  1205  	return &options{
  1206  		stdin:  stdin,
  1207  		stdout: stdout,
  1208  		stderr: stderr,
  1209  		tty:    tty,
  1210  	}
  1211  }
  1212  
  1213  // websocketStreams contains the WebSocket connection and streams from a server.
  1214  type websocketStreams struct {
  1215  	conn         io.Closer
  1216  	stdinStream  io.ReadCloser
  1217  	stdoutStream io.WriteCloser
  1218  	stderrStream io.WriteCloser
  1219  	writeStatus  func(status *apierrors.StatusError) error
  1220  	resizeStream io.ReadCloser
  1221  	resizeChan   chan TerminalSize
  1222  	tty          bool
  1223  }
  1224  
  1225  // Create WebSocket server streams to respond to a WebSocket client. Creates the streams passed
  1226  // in the stream options.
  1227  func webSocketServerStreams(req *http.Request, w http.ResponseWriter, opts *options) (*websocketStreams, error) {
  1228  	conn, err := createWebSocketStreams(req, w, opts)
  1229  	if err != nil {
  1230  		return nil, err
  1231  	}
  1232  
  1233  	if conn.resizeStream != nil {
  1234  		conn.resizeChan = make(chan TerminalSize)
  1235  		go handleResizeEvents(req.Context(), conn.resizeStream, conn.resizeChan)
  1236  	}
  1237  
  1238  	return conn, nil
  1239  }
  1240  
  1241  // Read terminal resize events off of passed stream and queue into passed channel.
  1242  func handleResizeEvents(ctx context.Context, stream io.Reader, channel chan<- TerminalSize) {
  1243  	defer close(channel)
  1244  
  1245  	decoder := json.NewDecoder(stream)
  1246  	for {
  1247  		size := TerminalSize{}
  1248  		if err := decoder.Decode(&size); err != nil {
  1249  			break
  1250  		}
  1251  
  1252  		select {
  1253  		case channel <- size:
  1254  		case <-ctx.Done():
  1255  			// To avoid leaking this routine, exit if the http request finishes. This path
  1256  			// would generally be hit if starting the process fails and nothing is started to
  1257  			// ingest these resize events.
  1258  			return
  1259  		}
  1260  	}
  1261  }
  1262  
  1263  // createChannels returns the standard channel types for a shell connection (STDIN 0, STDOUT 1, STDERR 2)
  1264  // along with the approximate duplex value. It also creates the error (3) and resize (4) channels.
  1265  func createChannels(opts *options) []wsstream.ChannelType {
  1266  	// open the requested channels, and always open the error channel
  1267  	channels := make([]wsstream.ChannelType, 5)
  1268  	channels[remotecommand.StreamStdIn] = readChannel(opts.stdin)
  1269  	channels[remotecommand.StreamStdOut] = writeChannel(opts.stdout)
  1270  	channels[remotecommand.StreamStdErr] = writeChannel(opts.stderr)
  1271  	channels[remotecommand.StreamErr] = wsstream.WriteChannel
  1272  	channels[remotecommand.StreamResize] = wsstream.ReadChannel
  1273  	return channels
  1274  }
  1275  
  1276  // readChannel returns wsstream.ReadChannel if real is true, or wsstream.IgnoreChannel.
  1277  func readChannel(real bool) wsstream.ChannelType {
  1278  	if real {
  1279  		return wsstream.ReadChannel
  1280  	}
  1281  	return wsstream.IgnoreChannel
  1282  }
  1283  
  1284  // writeChannel returns wsstream.WriteChannel if real is true, or wsstream.IgnoreChannel.
  1285  func writeChannel(real bool) wsstream.ChannelType {
  1286  	if real {
  1287  		return wsstream.WriteChannel
  1288  	}
  1289  	return wsstream.IgnoreChannel
  1290  }
  1291  
  1292  // createWebSocketStreams returns a "channels" struct containing the websocket connection and
  1293  // streams needed to perform an exec or an attach.
  1294  func createWebSocketStreams(req *http.Request, w http.ResponseWriter, opts *options) (*websocketStreams, error) {
  1295  	channels := createChannels(opts)
  1296  	conn := wsstream.NewConn(map[string]wsstream.ChannelProtocolConfig{
  1297  		remotecommand.StreamProtocolV5Name: {
  1298  			Binary:   true,
  1299  			Channels: channels,
  1300  		},
  1301  	})
  1302  	conn.SetIdleTimeout(4 * time.Hour)
  1303  	// Opening the connection responds to WebSocket client, negotiating
  1304  	// the WebSocket upgrade connection and the subprotocol.
  1305  	_, streams, err := conn.Open(w, req)
  1306  	if err != nil {
  1307  		return nil, err
  1308  	}
  1309  
  1310  	// Send an empty message to the lowest writable channel to notify the client the connection is established
  1311  	//nolint:errcheck
  1312  	switch {
  1313  	case opts.stdout:
  1314  		streams[remotecommand.StreamStdOut].Write([]byte{})
  1315  	case opts.stderr:
  1316  		streams[remotecommand.StreamStdErr].Write([]byte{})
  1317  	default:
  1318  		streams[remotecommand.StreamErr].Write([]byte{})
  1319  	}
  1320  
  1321  	wsStreams := &websocketStreams{
  1322  		conn:         conn,
  1323  		stdinStream:  streams[remotecommand.StreamStdIn],
  1324  		stdoutStream: streams[remotecommand.StreamStdOut],
  1325  		stderrStream: streams[remotecommand.StreamStdErr],
  1326  		tty:          opts.tty,
  1327  		resizeStream: streams[remotecommand.StreamResize],
  1328  	}
  1329  
  1330  	wsStreams.writeStatus = func(stream io.Writer) func(status *apierrors.StatusError) error {
  1331  		return func(status *apierrors.StatusError) error {
  1332  			bs, err := json.Marshal(status.Status())
  1333  			if err != nil {
  1334  				return err
  1335  			}
  1336  			_, err = stream.Write(bs)
  1337  			return err
  1338  		}
  1339  	}(streams[remotecommand.StreamErr])
  1340  
  1341  	return wsStreams, nil
  1342  }
  1343  

View as plain text