...

Source file src/k8s.io/kubernetes/pkg/kubelet/server/server_websocket_test.go

Documentation: k8s.io/kubernetes/pkg/kubelet/server

     1  /*
     2  Copyright 2016 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package server
    18  
    19  import (
    20  	"encoding/binary"
    21  	"fmt"
    22  	"io"
    23  	"strconv"
    24  	"sync"
    25  	"testing"
    26  
    27  	"github.com/stretchr/testify/assert"
    28  	"github.com/stretchr/testify/require"
    29  	"golang.org/x/net/websocket"
    30  
    31  	"k8s.io/apimachinery/pkg/types"
    32  	"k8s.io/kubelet/pkg/cri/streaming/portforward"
    33  )
    34  
    35  const (
    36  	dataChannel = iota
    37  	errorChannel
    38  )
    39  
    40  func TestServeWSPortForward(t *testing.T) {
    41  	tests := map[string]struct {
    42  		port          string
    43  		uid           bool
    44  		clientData    string
    45  		containerData string
    46  		shouldError   bool
    47  	}{
    48  		"no port":                       {port: "", shouldError: true},
    49  		"none number port":              {port: "abc", shouldError: true},
    50  		"negative port":                 {port: "-1", shouldError: true},
    51  		"too large port":                {port: "65536", shouldError: true},
    52  		"0 port":                        {port: "0", shouldError: true},
    53  		"min port":                      {port: "1", shouldError: false},
    54  		"normal port":                   {port: "8000", shouldError: false},
    55  		"normal port with data forward": {port: "8000", clientData: "client data", containerData: "container data", shouldError: false},
    56  		"max port":                      {port: "65535", shouldError: false},
    57  		"normal port with uid":          {port: "8000", uid: true, shouldError: false},
    58  	}
    59  
    60  	podNamespace := "other"
    61  	podName := "foo"
    62  
    63  	for desc := range tests {
    64  		test := tests[desc]
    65  		t.Run(desc, func(t *testing.T) {
    66  			ss, err := newTestStreamingServer(0)
    67  			require.NoError(t, err)
    68  			defer ss.testHTTPServer.Close()
    69  			fw := newServerTestWithDebug(true, ss)
    70  			defer fw.testHTTPServer.Close()
    71  
    72  			portForwardFuncDone := make(chan struct{})
    73  
    74  			fw.fakeKubelet.getPortForwardCheck = func(name, namespace string, uid types.UID, opts portforward.V4Options) {
    75  				assert.Equal(t, podName, name, "pod name")
    76  				assert.Equal(t, podNamespace, namespace, "pod namespace")
    77  				if test.uid {
    78  					assert.Equal(t, testUID, string(uid), "uid")
    79  				}
    80  			}
    81  
    82  			ss.fakeRuntime.portForwardFunc = func(podSandboxID string, port int32, stream io.ReadWriteCloser) error {
    83  				defer close(portForwardFuncDone)
    84  				assert.Equal(t, testPodSandboxID, podSandboxID, "pod sandbox id")
    85  				// The port should be valid if it reaches here.
    86  				testPort, err := strconv.ParseInt(test.port, 10, 32)
    87  				require.NoError(t, err, "parse port")
    88  				assert.Equal(t, int32(testPort), port, "port")
    89  
    90  				if test.clientData != "" {
    91  					fromClient := make([]byte, 32)
    92  					n, err := stream.Read(fromClient)
    93  					assert.NoError(t, err, "reading client data")
    94  					assert.Equal(t, test.clientData, string(fromClient[0:n]), "client data")
    95  				}
    96  
    97  				if test.containerData != "" {
    98  					_, err := stream.Write([]byte(test.containerData))
    99  					assert.NoError(t, err, "writing container data")
   100  				}
   101  
   102  				return nil
   103  			}
   104  
   105  			var url string
   106  			if test.uid {
   107  				url = fmt.Sprintf("ws://%s/portForward/%s/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, testUID, test.port)
   108  			} else {
   109  				url = fmt.Sprintf("ws://%s/portForward/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, test.port)
   110  			}
   111  
   112  			ws, err := websocket.Dial(url, "", "http://127.0.0.1/")
   113  			assert.Equal(t, test.shouldError, err != nil, "websocket dial")
   114  			if test.shouldError {
   115  				return
   116  			}
   117  			defer ws.Close()
   118  
   119  			p, err := strconv.ParseUint(test.port, 10, 16)
   120  			require.NoError(t, err, "parse port")
   121  			p16 := uint16(p)
   122  
   123  			channel, data, err := wsRead(ws)
   124  			require.NoError(t, err, "read")
   125  			assert.Equal(t, dataChannel, int(channel), "channel")
   126  			assert.Len(t, data, binary.Size(p16), "data size")
   127  			assert.Equal(t, p16, binary.LittleEndian.Uint16(data), "data")
   128  
   129  			channel, data, err = wsRead(ws)
   130  			assert.NoError(t, err, "read")
   131  			assert.Equal(t, errorChannel, int(channel), "channel")
   132  			assert.Len(t, data, binary.Size(p16), "data size")
   133  			assert.Equal(t, p16, binary.LittleEndian.Uint16(data), "data")
   134  
   135  			if test.clientData != "" {
   136  				println("writing the client data")
   137  				err := wsWrite(ws, dataChannel, []byte(test.clientData))
   138  				assert.NoError(t, err, "writing client data")
   139  			}
   140  
   141  			if test.containerData != "" {
   142  				_, data, err = wsRead(ws)
   143  				assert.NoError(t, err, "reading container data")
   144  				assert.Equal(t, test.containerData, string(data), "container data")
   145  			}
   146  
   147  			<-portForwardFuncDone
   148  		})
   149  	}
   150  }
   151  
   152  func TestServeWSMultiplePortForward(t *testing.T) {
   153  	portsText := []string{"7000,8000", "9000"}
   154  	ports := []uint16{7000, 8000, 9000}
   155  	podNamespace := "other"
   156  	podName := "foo"
   157  
   158  	ss, err := newTestStreamingServer(0)
   159  	require.NoError(t, err)
   160  	defer ss.testHTTPServer.Close()
   161  	fw := newServerTestWithDebug(true, ss)
   162  	defer fw.testHTTPServer.Close()
   163  
   164  	portForwardWG := sync.WaitGroup{}
   165  	portForwardWG.Add(len(ports))
   166  
   167  	portsMutex := sync.Mutex{}
   168  	portsForwarded := map[int32]struct{}{}
   169  
   170  	fw.fakeKubelet.getPortForwardCheck = func(name, namespace string, uid types.UID, opts portforward.V4Options) {
   171  		assert.Equal(t, podName, name, "pod name")
   172  		assert.Equal(t, podNamespace, namespace, "pod namespace")
   173  	}
   174  
   175  	ss.fakeRuntime.portForwardFunc = func(podSandboxID string, port int32, stream io.ReadWriteCloser) error {
   176  		defer portForwardWG.Done()
   177  		assert.Equal(t, testPodSandboxID, podSandboxID, "pod sandbox id")
   178  
   179  		portsMutex.Lock()
   180  		portsForwarded[port] = struct{}{}
   181  		portsMutex.Unlock()
   182  
   183  		fromClient := make([]byte, 32)
   184  		n, err := stream.Read(fromClient)
   185  		assert.NoError(t, err, "reading client data")
   186  		assert.Equal(t, fmt.Sprintf("client data on port %d", port), string(fromClient[0:n]), "client data")
   187  
   188  		_, err = stream.Write([]byte(fmt.Sprintf("container data on port %d", port)))
   189  		assert.NoError(t, err, "writing container data")
   190  
   191  		return nil
   192  	}
   193  
   194  	url := fmt.Sprintf("ws://%s/portForward/%s/%s?", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName)
   195  	for _, port := range portsText {
   196  		url = url + fmt.Sprintf("port=%s&", port)
   197  	}
   198  
   199  	ws, err := websocket.Dial(url, "", "http://127.0.0.1/")
   200  	require.NoError(t, err, "websocket dial")
   201  
   202  	defer ws.Close()
   203  
   204  	for i, port := range ports {
   205  		channel, data, err := wsRead(ws)
   206  		assert.NoError(t, err, "port %d read", port)
   207  		assert.Equal(t, i*2+dataChannel, int(channel), "port %d channel", port)
   208  		assert.Len(t, data, binary.Size(port), "port %d data size", port)
   209  		assert.Equal(t, binary.LittleEndian.Uint16(data), port, "port %d data", port)
   210  
   211  		channel, data, err = wsRead(ws)
   212  		assert.NoError(t, err, "port %d read", port)
   213  		assert.Equal(t, i*2+errorChannel, int(channel), "port %d channel", port)
   214  		assert.Len(t, data, binary.Size(port), "port %d data size", port)
   215  		assert.Equal(t, binary.LittleEndian.Uint16(data), port, "port %d data", port)
   216  	}
   217  
   218  	for i, port := range ports {
   219  		t.Logf("port %d writing the client data", port)
   220  		err := wsWrite(ws, byte(i*2+dataChannel), []byte(fmt.Sprintf("client data on port %d", port)))
   221  		assert.NoError(t, err, "port %d write client data", port)
   222  
   223  		channel, data, err := wsRead(ws)
   224  		assert.NoError(t, err, "port %d read container data", port)
   225  		assert.Equal(t, i*2+dataChannel, int(channel), "port %d channel", port)
   226  		assert.Equal(t, fmt.Sprintf("container data on port %d", port), string(data), "port %d container data", port)
   227  	}
   228  
   229  	portForwardWG.Wait()
   230  
   231  	portsMutex.Lock()
   232  	defer portsMutex.Unlock()
   233  	assert.Len(t, portsForwarded, len(ports), "all ports forwarded")
   234  }
   235  
   236  func wsWrite(conn *websocket.Conn, channel byte, data []byte) error {
   237  	frame := make([]byte, len(data)+1)
   238  	frame[0] = channel
   239  	copy(frame[1:], data)
   240  	err := websocket.Message.Send(conn, frame)
   241  	return err
   242  }
   243  
   244  func wsRead(conn *websocket.Conn) (byte, []byte, error) {
   245  	for {
   246  		var data []byte
   247  		err := websocket.Message.Receive(conn, &data)
   248  		if err != nil {
   249  			return 0, nil, err
   250  		}
   251  
   252  		if len(data) == 0 {
   253  			continue
   254  		}
   255  
   256  		channel := data[0]
   257  		data = data[1:]
   258  
   259  		return channel, data, err
   260  	}
   261  }
   262  

View as plain text