...

Source file src/k8s.io/client-go/transport/websocket/roundtripper_test.go

Documentation: k8s.io/client-go/transport/websocket

     1  /*
     2  Copyright 2023 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package websocket
    18  
    19  import (
    20  	"context"
    21  	"io"
    22  	"net/http"
    23  	"net/http/httptest"
    24  	"net/url"
    25  	"strings"
    26  	"testing"
    27  	"time"
    28  
    29  	"github.com/stretchr/testify/assert"
    30  	"github.com/stretchr/testify/require"
    31  
    32  	"k8s.io/apimachinery/pkg/util/httpstream"
    33  	"k8s.io/apimachinery/pkg/util/httpstream/wsstream"
    34  	"k8s.io/apimachinery/pkg/util/remotecommand"
    35  	restclient "k8s.io/client-go/rest"
    36  )
    37  
    38  func TestWebSocketRoundTripper_RoundTripperSucceeds(t *testing.T) {
    39  	// Create fake WebSocket server.
    40  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
    41  		conns, err := webSocketServerStreams(req, w)
    42  		if err != nil {
    43  			t.Fatalf("error on webSocketServerStreams: %v", err)
    44  		}
    45  		defer conns.conn.Close()
    46  	}))
    47  	defer websocketServer.Close()
    48  
    49  	// Create the wrapped roundtripper and websocket upgrade roundtripper and call "RoundTrip()".
    50  	websocketLocation, err := url.Parse(websocketServer.URL)
    51  	require.NoError(t, err)
    52  	req, err := http.NewRequestWithContext(context.Background(), "GET", websocketServer.URL, nil)
    53  	require.NoError(t, err)
    54  	rt, wsRt, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host})
    55  	require.NoError(t, err)
    56  	requestedProtocol := remotecommand.StreamProtocolV5Name
    57  	req.Header[wsstream.WebSocketProtocolHeader] = []string{requestedProtocol}
    58  	_, err = rt.RoundTrip(req)
    59  	require.NoError(t, err)
    60  	// WebSocket Connection is stored in websocket RoundTripper.
    61  	// Compare the expected negotiated subprotocol with the actual subprotocol.
    62  	actualProtocol := wsRt.Connection().Subprotocol()
    63  	assert.Equal(t, requestedProtocol, actualProtocol)
    64  
    65  }
    66  
    67  func TestWebSocketRoundTripper_RoundTripperFails(t *testing.T) {
    68  	// Create fake WebSocket server.
    69  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
    70  		// Bad handshake means websocket server will not completely initialize.
    71  		_, err := webSocketServerStreams(req, w)
    72  		require.Error(t, err)
    73  		assert.True(t, strings.Contains(err.Error(), "websocket server finished before becoming ready"))
    74  	}))
    75  	defer websocketServer.Close()
    76  
    77  	// Create the wrapped roundtripper and websocket upgrade roundtripper and call "RoundTrip()".
    78  	websocketLocation, err := url.Parse(websocketServer.URL)
    79  	require.NoError(t, err)
    80  	req, err := http.NewRequestWithContext(context.Background(), "GET", websocketServer.URL, nil)
    81  	require.NoError(t, err)
    82  	rt, _, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host})
    83  	require.NoError(t, err)
    84  	// Requested subprotocol version 1 is not supported by test websocket server.
    85  	requestedProtocol := remotecommand.StreamProtocolV1Name
    86  	req.Header[wsstream.WebSocketProtocolHeader] = []string{requestedProtocol}
    87  	_, err = rt.RoundTrip(req)
    88  	// Ensure a "bad handshake" error is returned, since requested protocol is not supported.
    89  	require.Error(t, err)
    90  	assert.True(t, strings.Contains(err.Error(), "bad handshake"))
    91  	assert.True(t, httpstream.IsUpgradeFailure(err))
    92  }
    93  
    94  func TestWebSocketRoundTripper_NegotiateCreatesConnection(t *testing.T) {
    95  	// Create fake WebSocket server.
    96  	websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
    97  		conns, err := webSocketServerStreams(req, w)
    98  		if err != nil {
    99  			t.Fatalf("error on webSocketServerStreams: %v", err)
   100  		}
   101  		defer conns.conn.Close()
   102  	}))
   103  	defer websocketServer.Close()
   104  
   105  	// Create the websocket roundtripper and call "Negotiate" to create websocket connection.
   106  	websocketLocation, err := url.Parse(websocketServer.URL)
   107  	require.NoError(t, err)
   108  	req, err := http.NewRequestWithContext(context.Background(), "GET", websocketServer.URL, nil)
   109  	require.NoError(t, err)
   110  	rt, wsRt, err := RoundTripperFor(&restclient.Config{Host: websocketLocation.Host})
   111  	require.NoError(t, err)
   112  	requestedProtocol := remotecommand.StreamProtocolV5Name
   113  	conn, err := Negotiate(rt, wsRt, req, requestedProtocol)
   114  	require.NoError(t, err)
   115  	// Compare the expected negotiated subprotocol with the actual subprotocol.
   116  	actualProtocol := conn.Subprotocol()
   117  	assert.Equal(t, requestedProtocol, actualProtocol)
   118  }
   119  
   120  // websocketStreams contains the WebSocket connection and streams from a server.
   121  type websocketStreams struct {
   122  	conn io.Closer
   123  }
   124  
   125  func webSocketServerStreams(req *http.Request, w http.ResponseWriter) (*websocketStreams, error) {
   126  	conn := wsstream.NewConn(map[string]wsstream.ChannelProtocolConfig{
   127  		remotecommand.StreamProtocolV5Name: {
   128  			Binary:   true,
   129  			Channels: []wsstream.ChannelType{},
   130  		},
   131  	})
   132  	conn.SetIdleTimeout(4 * time.Hour)
   133  	// Opening the connection responds to WebSocket client, negotiating
   134  	// the WebSocket upgrade connection and the subprotocol.
   135  	_, _, err := conn.Open(w, req)
   136  	if err != nil {
   137  		return nil, err
   138  	}
   139  	return &websocketStreams{conn: conn}, nil
   140  }
   141  

View as plain text