...

Source file src/k8s.io/client-go/tools/remotecommand/fallback_test.go

Documentation: k8s.io/client-go/tools/remotecommand

     1  /*
     2  Copyright 2023 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package remotecommand
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"crypto/rand"
    23  	"io"
    24  	"net/http"
    25  	"net/http/httptest"
    26  	"net/url"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/stretchr/testify/assert"
    31  	"github.com/stretchr/testify/require"
    32  	"k8s.io/apimachinery/pkg/util/remotecommand"
    33  	"k8s.io/apimachinery/pkg/util/wait"
    34  	"k8s.io/client-go/rest"
    35  )
    36  
    37  func TestFallbackClient_WebSocketPrimarySucceeds(t *testing.T) {
    38  	// Create fake WebSocket server. Copy received STDIN data back onto STDOUT stream.
    39  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
    40  		conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
    41  		if err != nil {
    42  			w.WriteHeader(http.StatusForbidden)
    43  			return
    44  		}
    45  		defer conns.conn.Close()
    46  		// Loopback the STDIN stream onto the STDOUT stream.
    47  		_, err = io.Copy(conns.stdoutStream, conns.stdinStream)
    48  		require.NoError(t, err)
    49  	}))
    50  	defer websocketServer.Close()
    51  
    52  	// Now create the fallback client (executor), and point it to the "websocketServer".
    53  	// Must add STDIN and STDOUT query params for the client request.
    54  	websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true"
    55  	websocketLocation, err := url.Parse(websocketServer.URL)
    56  	require.NoError(t, err)
    57  	websocketExecutor, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
    58  	require.NoError(t, err)
    59  	spdyExecutor, err := NewSPDYExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketLocation)
    60  	require.NoError(t, err)
    61  	// Never fallback, so always use the websocketExecutor, which succeeds against websocket server.
    62  	exec, err := NewFallbackExecutor(websocketExecutor, spdyExecutor, func(error) bool { return false })
    63  	require.NoError(t, err)
    64  	// Generate random data, and set it up to stream on STDIN. The data will be
    65  	// returned on the STDOUT buffer.
    66  	randomSize := 1024 * 1024
    67  	randomData := make([]byte, randomSize)
    68  	if _, err := rand.Read(randomData); err != nil {
    69  		t.Errorf("unexpected error reading random data: %v", err)
    70  	}
    71  	var stdout bytes.Buffer
    72  	options := &StreamOptions{
    73  		Stdin:  bytes.NewReader(randomData),
    74  		Stdout: &stdout,
    75  	}
    76  	errorChan := make(chan error)
    77  	go func() {
    78  		// Start the streaming on the WebSocket "exec" client.
    79  		errorChan <- exec.StreamWithContext(context.Background(), *options)
    80  	}()
    81  
    82  	select {
    83  	case <-time.After(wait.ForeverTestTimeout):
    84  		t.Fatalf("expect stream to be closed after connection is closed.")
    85  	case err := <-errorChan:
    86  		if err != nil {
    87  			t.Errorf("unexpected error")
    88  		}
    89  	}
    90  
    91  	data, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
    92  	if err != nil {
    93  		t.Errorf("error reading the stream: %v", err)
    94  		return
    95  	}
    96  	// Check the random data sent on STDIN was the same returned on STDOUT.
    97  	if !bytes.Equal(randomData, data) {
    98  		t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData))
    99  	}
   100  }
   101  
   102  func TestFallbackClient_SPDYSecondarySucceeds(t *testing.T) {
   103  	// Create fake SPDY server. Copy received STDIN data back onto STDOUT stream.
   104  	spdyServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   105  		var stdin, stdout bytes.Buffer
   106  		ctx, err := createHTTPStreams(w, req, &StreamOptions{
   107  			Stdin:  &stdin,
   108  			Stdout: &stdout,
   109  		})
   110  		if err != nil {
   111  			w.WriteHeader(http.StatusForbidden)
   112  			return
   113  		}
   114  		defer ctx.conn.Close()
   115  		_, err = io.Copy(ctx.stdoutStream, ctx.stdinStream)
   116  		if err != nil {
   117  			t.Fatalf("error copying STDIN to STDOUT: %v", err)
   118  		}
   119  	}))
   120  	defer spdyServer.Close()
   121  
   122  	spdyLocation, err := url.Parse(spdyServer.URL)
   123  	require.NoError(t, err)
   124  	websocketExecutor, err := NewWebSocketExecutor(&rest.Config{Host: spdyLocation.Host}, "GET", spdyServer.URL)
   125  	require.NoError(t, err)
   126  	spdyExecutor, err := NewSPDYExecutor(&rest.Config{Host: spdyLocation.Host}, "POST", spdyLocation)
   127  	require.NoError(t, err)
   128  	// Always fallback to spdyExecutor, and spdyExecutor succeeds against fake spdy server.
   129  	exec, err := NewFallbackExecutor(websocketExecutor, spdyExecutor, func(error) bool { return true })
   130  	require.NoError(t, err)
   131  	// Generate random data, and set it up to stream on STDIN. The data will be
   132  	// returned on the STDOUT buffer.
   133  	randomSize := 1024 * 1024
   134  	randomData := make([]byte, randomSize)
   135  	if _, err := rand.Read(randomData); err != nil {
   136  		t.Errorf("unexpected error reading random data: %v", err)
   137  	}
   138  	var stdout bytes.Buffer
   139  	options := &StreamOptions{
   140  		Stdin:  bytes.NewReader(randomData),
   141  		Stdout: &stdout,
   142  	}
   143  	errorChan := make(chan error)
   144  	go func() {
   145  		errorChan <- exec.StreamWithContext(context.Background(), *options)
   146  	}()
   147  
   148  	select {
   149  	case <-time.After(wait.ForeverTestTimeout):
   150  		t.Fatalf("expect stream to be closed after connection is closed.")
   151  	case err := <-errorChan:
   152  		if err != nil {
   153  			t.Errorf("unexpected error")
   154  		}
   155  	}
   156  
   157  	data, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
   158  	if err != nil {
   159  		t.Errorf("error reading the stream: %v", err)
   160  		return
   161  	}
   162  	// Check the random data sent on STDIN was the same returned on STDOUT.
   163  	if !bytes.Equal(randomData, data) {
   164  		t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData))
   165  	}
   166  }
   167  
   168  func TestFallbackClient_PrimaryAndSecondaryFail(t *testing.T) {
   169  	// Create fake WebSocket server. Copy received STDIN data back onto STDOUT stream.
   170  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   171  		conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
   172  		if err != nil {
   173  			w.WriteHeader(http.StatusForbidden)
   174  			return
   175  		}
   176  		defer conns.conn.Close()
   177  		// Loopback the STDIN stream onto the STDOUT stream.
   178  		_, err = io.Copy(conns.stdoutStream, conns.stdinStream)
   179  		require.NoError(t, err)
   180  	}))
   181  	defer websocketServer.Close()
   182  
   183  	// Now create the fallback client (executor), and point it to the "websocketServer".
   184  	// Must add STDIN and STDOUT query params for the client request.
   185  	websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true"
   186  	websocketLocation, err := url.Parse(websocketServer.URL)
   187  	require.NoError(t, err)
   188  	websocketExecutor, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
   189  	require.NoError(t, err)
   190  	spdyExecutor, err := NewSPDYExecutor(&rest.Config{Host: websocketLocation.Host}, "POST", websocketLocation)
   191  	require.NoError(t, err)
   192  	// Always fallback to spdyExecutor, but spdyExecutor fails against websocket server.
   193  	exec, err := NewFallbackExecutor(websocketExecutor, spdyExecutor, func(error) bool { return true })
   194  	require.NoError(t, err)
   195  	// Update the websocket executor to request remote command v4, which is unsupported.
   196  	fallbackExec, ok := exec.(*FallbackExecutor)
   197  	assert.True(t, ok, "error casting executor as FallbackExecutor")
   198  	websocketExec, ok := fallbackExec.primary.(*wsStreamExecutor)
   199  	assert.True(t, ok, "error casting executor as websocket executor")
   200  	// Set the attempted subprotocol version to V4; websocket server only accepts V5.
   201  	websocketExec.protocols = []string{remotecommand.StreamProtocolV4Name}
   202  
   203  	// Generate random data, and set it up to stream on STDIN. The data will be
   204  	// returned on the STDOUT buffer.
   205  	randomSize := 1024 * 1024
   206  	randomData := make([]byte, randomSize)
   207  	if _, err := rand.Read(randomData); err != nil {
   208  		t.Errorf("unexpected error reading random data: %v", err)
   209  	}
   210  	var stdout bytes.Buffer
   211  	options := &StreamOptions{
   212  		Stdin:  bytes.NewReader(randomData),
   213  		Stdout: &stdout,
   214  	}
   215  	errorChan := make(chan error)
   216  	go func() {
   217  		errorChan <- exec.StreamWithContext(context.Background(), *options)
   218  	}()
   219  
   220  	select {
   221  	case <-time.After(wait.ForeverTestTimeout):
   222  		t.Fatalf("expect stream to be closed after connection is closed.")
   223  	case err := <-errorChan:
   224  		// Ensure secondary executor returned an error.
   225  		require.Error(t, err)
   226  	}
   227  }
   228  

View as plain text