...

Source file src/k8s.io/client-go/tools/portforward/tunneling_connection_test.go

Documentation: k8s.io/client-go/tools/portforward

     1  /*
     2  Copyright 2024 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package portforward
    18  
    19  import (
    20  	"io"
    21  	"net"
    22  	"net/http"
    23  	"net/http/httptest"
    24  	"net/url"
    25  	"strings"
    26  	"testing"
    27  	"time"
    28  
    29  	gwebsocket "github.com/gorilla/websocket"
    30  	"github.com/stretchr/testify/assert"
    31  	"github.com/stretchr/testify/require"
    32  
    33  	"k8s.io/apimachinery/pkg/util/httpstream"
    34  	"k8s.io/apimachinery/pkg/util/httpstream/spdy"
    35  	constants "k8s.io/apimachinery/pkg/util/portforward"
    36  	"k8s.io/apimachinery/pkg/util/wait"
    37  	"k8s.io/client-go/rest"
    38  	"k8s.io/client-go/transport/websocket"
    39  )
    40  
    41  func TestTunnelingConnection_ReadWriteClose(t *testing.T) {
    42  	// Stream channel that will receive streams created on upstream SPDY server.
    43  	streamChan := make(chan httpstream.Stream)
    44  	defer close(streamChan)
    45  	stopServerChan := make(chan struct{})
    46  	defer close(stopServerChan)
    47  	// Create tunneling connection server endpoint with fake upstream SPDY server.
    48  	tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
    49  		var upgrader = gwebsocket.Upgrader{
    50  			CheckOrigin:  func(r *http.Request) bool { return true },
    51  			Subprotocols: []string{constants.WebsocketsSPDYTunnelingPortForwardV1},
    52  		}
    53  		conn, err := upgrader.Upgrade(w, req, nil)
    54  		require.NoError(t, err)
    55  		defer conn.Close() //nolint:errcheck
    56  		require.Equal(t, constants.WebsocketsSPDYTunnelingPortForwardV1, conn.Subprotocol())
    57  		tunnelingConn := NewTunnelingConnection("server", conn)
    58  		spdyConn, err := spdy.NewServerConnection(tunnelingConn, justQueueStream(streamChan))
    59  		require.NoError(t, err)
    60  		defer spdyConn.Close() //nolint:errcheck
    61  		<-stopServerChan
    62  	}))
    63  	defer tunnelingServer.Close()
    64  	// Dial the client tunneling connection to the tunneling server.
    65  	url, err := url.Parse(tunnelingServer.URL)
    66  	require.NoError(t, err)
    67  	dialer, err := NewSPDYOverWebsocketDialer(url, &rest.Config{Host: url.Host})
    68  	require.NoError(t, err)
    69  	spdyClient, protocol, err := dialer.Dial(constants.PortForwardV1Name)
    70  	require.NoError(t, err)
    71  	assert.Equal(t, constants.PortForwardV1Name, protocol)
    72  	defer spdyClient.Close() //nolint:errcheck
    73  	// Create a SPDY client stream, which will queue a SPDY server stream
    74  	// on the stream creation channel. Send data on the client stream
    75  	// reading off the SPDY server stream, and validating it was tunneled.
    76  	expected := "This is a test tunneling SPDY data through websockets."
    77  	var actual []byte
    78  	go func() {
    79  		clientStream, err := spdyClient.CreateStream(http.Header{})
    80  		require.NoError(t, err)
    81  		_, err = io.Copy(clientStream, strings.NewReader(expected))
    82  		require.NoError(t, err)
    83  		clientStream.Close() //nolint:errcheck
    84  	}()
    85  	select {
    86  	case serverStream := <-streamChan:
    87  		actual, err = io.ReadAll(serverStream)
    88  		require.NoError(t, err)
    89  		defer serverStream.Close() //nolint:errcheck
    90  	case <-time.After(wait.ForeverTestTimeout):
    91  		t.Fatalf("timeout waiting for spdy stream to arrive on channel.")
    92  	}
    93  	assert.Equal(t, expected, string(actual), "error validating tunneled string")
    94  }
    95  
    96  func TestTunnelingConnection_LocalRemoteAddress(t *testing.T) {
    97  	stopServerChan := make(chan struct{})
    98  	defer close(stopServerChan)
    99  	tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   100  		var upgrader = gwebsocket.Upgrader{
   101  			CheckOrigin:  func(r *http.Request) bool { return true },
   102  			Subprotocols: []string{constants.WebsocketsSPDYTunnelingPortForwardV1},
   103  		}
   104  		conn, err := upgrader.Upgrade(w, req, nil)
   105  		require.NoError(t, err)
   106  		defer conn.Close() //nolint:errcheck
   107  		require.Equal(t, constants.WebsocketsSPDYTunnelingPortForwardV1, conn.Subprotocol())
   108  		<-stopServerChan
   109  	}))
   110  	defer tunnelingServer.Close()
   111  	// Create the client side tunneling connection.
   112  	url, err := url.Parse(tunnelingServer.URL)
   113  	require.NoError(t, err)
   114  	tConn, err := dialForTunnelingConnection(url)
   115  	require.NoError(t, err, "error creating client tunneling connection")
   116  	defer tConn.Close() //nolint:errcheck
   117  	// Validate "LocalAddr()" and "RemoteAddr()"
   118  	localAddr := tConn.LocalAddr()
   119  	remoteAddr := tConn.RemoteAddr()
   120  	assert.Equal(t, "tcp", localAddr.Network(), "tunneling connection must be TCP")
   121  	assert.Equal(t, "tcp", remoteAddr.Network(), "tunneling connection must be TCP")
   122  	_, err = net.ResolveTCPAddr("tcp", localAddr.String())
   123  	assert.NoError(t, err, "tunneling connection local addr should parse")
   124  	_, err = net.ResolveTCPAddr("tcp", remoteAddr.String())
   125  	assert.NoError(t, err, "tunneling connection remote addr should parse")
   126  }
   127  
   128  func TestTunnelingConnection_ReadWriteDeadlines(t *testing.T) {
   129  	stopServerChan := make(chan struct{})
   130  	defer close(stopServerChan)
   131  	tunnelingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   132  		var upgrader = gwebsocket.Upgrader{
   133  			CheckOrigin:  func(r *http.Request) bool { return true },
   134  			Subprotocols: []string{constants.WebsocketsSPDYTunnelingPortForwardV1},
   135  		}
   136  		conn, err := upgrader.Upgrade(w, req, nil)
   137  		require.NoError(t, err)
   138  		defer conn.Close() //nolint:errcheck
   139  		require.Equal(t, constants.WebsocketsSPDYTunnelingPortForwardV1, conn.Subprotocol())
   140  		<-stopServerChan
   141  	}))
   142  	defer tunnelingServer.Close()
   143  	// Create the client side tunneling connection.
   144  	url, err := url.Parse(tunnelingServer.URL)
   145  	require.NoError(t, err)
   146  	tConn, err := dialForTunnelingConnection(url)
   147  	require.NoError(t, err, "error creating client tunneling connection")
   148  	defer tConn.Close() //nolint:errcheck
   149  	// Validate the read and write deadlines.
   150  	err = tConn.SetReadDeadline(time.Time{})
   151  	assert.NoError(t, err, "setting zero deadline should always succeed; turns off deadline")
   152  	err = tConn.SetWriteDeadline(time.Time{})
   153  	assert.NoError(t, err, "setting zero deadline should always succeed; turns off deadline")
   154  	err = tConn.SetDeadline(time.Time{})
   155  	assert.NoError(t, err, "setting zero deadline should always succeed; turns off deadline")
   156  	err = tConn.SetReadDeadline(time.Now().AddDate(10, 0, 0))
   157  	assert.NoError(t, err, "setting deadline 10 year from now succeeds")
   158  	err = tConn.SetWriteDeadline(time.Now().AddDate(10, 0, 0))
   159  	assert.NoError(t, err, "setting deadline 10 year from now succeeds")
   160  	err = tConn.SetDeadline(time.Now().AddDate(10, 0, 0))
   161  	assert.NoError(t, err, "setting deadline 10 year from now succeeds")
   162  }
   163  
   164  // dialForTunnelingConnection upgrades a request at the passed "url", creating
   165  // a websocket connection. Returns the TunnelingConnection injected with the
   166  // websocket connection or an error if one occurs.
   167  func dialForTunnelingConnection(url *url.URL) (*TunnelingConnection, error) {
   168  	req, err := http.NewRequest("GET", url.String(), nil)
   169  	if err != nil {
   170  		return nil, err
   171  	}
   172  	// Tunneling must initiate a websocket upgrade connection, using tunneling portforward protocol.
   173  	tunnelingProtocols := []string{constants.WebsocketsSPDYTunnelingPortForwardV1}
   174  	transport, holder, err := websocket.RoundTripperFor(&rest.Config{Host: url.Host})
   175  	if err != nil {
   176  		return nil, err
   177  	}
   178  	conn, err := websocket.Negotiate(transport, holder, req, tunnelingProtocols...)
   179  	if err != nil {
   180  		return nil, err
   181  	}
   182  	return NewTunnelingConnection("client", conn), nil
   183  }
   184  
   185  func justQueueStream(streams chan httpstream.Stream) func(httpstream.Stream, <-chan struct{}) error {
   186  	return func(stream httpstream.Stream, replySent <-chan struct{}) error {
   187  		streams <- stream
   188  		return nil
   189  	}
   190  }
   191  

View as plain text