...

Source file src/k8s.io/kubernetes/pkg/client/tests/portfoward_test.go

Documentation: k8s.io/kubernetes/pkg/client/tests

     1  /*
     2  Copyright 2015 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package tests
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"fmt"
    23  	"io"
    24  	"net"
    25  	"net/http"
    26  	"net/http/httptest"
    27  	"net/url"
    28  	"os"
    29  	"strings"
    30  	"sync"
    31  	"testing"
    32  	"time"
    33  
    34  	"k8s.io/apimachinery/pkg/types"
    35  	restclient "k8s.io/client-go/rest"
    36  	. "k8s.io/client-go/tools/portforward"
    37  	"k8s.io/client-go/transport/spdy"
    38  	"k8s.io/kubelet/pkg/cri/streaming/portforward"
    39  )
    40  
    41  // fakePortForwarder simulates port forwarding for testing. It implements
    42  // portforward.PortForwarder.
    43  type fakePortForwarder struct {
    44  	lock sync.Mutex
    45  	// stores data expected from the stream per port
    46  	expected map[int32]string
    47  	// stores data received from the stream per port
    48  	received map[int32]string
    49  	// data to be sent to the stream per port
    50  	send map[int32]string
    51  }
    52  
    53  var _ portforward.PortForwarder = &fakePortForwarder{}
    54  
    55  func (pf *fakePortForwarder) PortForward(_ context.Context, name string, uid types.UID, port int32, stream io.ReadWriteCloser) error {
    56  	defer stream.Close()
    57  
    58  	// read from the client
    59  	received := make([]byte, len(pf.expected[port]))
    60  	n, err := stream.Read(received)
    61  	if err != nil {
    62  		return fmt.Errorf("error reading from client for port %d: %v", port, err)
    63  	}
    64  	if n != len(pf.expected[port]) {
    65  		return fmt.Errorf("unexpected length read from client for port %d: got %d, expected %d. data=%q", port, n, len(pf.expected[port]), string(received))
    66  	}
    67  
    68  	// store the received content
    69  	pf.lock.Lock()
    70  	pf.received[port] = string(received)
    71  	pf.lock.Unlock()
    72  
    73  	// send the hardcoded data to the client
    74  	io.Copy(stream, strings.NewReader(pf.send[port]))
    75  
    76  	return nil
    77  }
    78  
    79  // fakePortForwardServer creates an HTTP server that can handle port forwarding
    80  // requests.
    81  func fakePortForwardServer(t *testing.T, testName string, serverSends, expectedFromClient map[int32]string) http.HandlerFunc {
    82  	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
    83  		pf := &fakePortForwarder{
    84  			expected: expectedFromClient,
    85  			received: make(map[int32]string),
    86  			send:     serverSends,
    87  		}
    88  		portforward.ServePortForward(w, req, pf, "pod", "uid", nil, 0, 10*time.Second, portforward.SupportedProtocols)
    89  
    90  		for port, expected := range expectedFromClient {
    91  			actual, ok := pf.received[port]
    92  			if !ok {
    93  				t.Errorf("%s: server didn't receive any data for port %d", testName, port)
    94  				continue
    95  			}
    96  
    97  			if expected != actual {
    98  				t.Errorf("%s: server expected to receive %q, got %q for port %d", testName, expected, actual, port)
    99  			}
   100  		}
   101  
   102  		for port, actual := range pf.received {
   103  			if _, ok := expectedFromClient[port]; !ok {
   104  				t.Errorf("%s: server unexpectedly received %q for port %d", testName, actual, port)
   105  			}
   106  		}
   107  	})
   108  }
   109  
   110  func TestForwardPorts(t *testing.T) {
   111  	tests := map[string]struct {
   112  		ports       []string
   113  		clientSends map[int32]string
   114  		serverSends map[int32]string
   115  	}{
   116  		"forward 1 port with no data either direction": {
   117  			ports: []string{":5000"},
   118  		},
   119  		"forward 2 ports with bidirectional data": {
   120  			ports: []string{":5001", ":6000"},
   121  			clientSends: map[int32]string{
   122  				5001: "abcd",
   123  				6000: "ghij",
   124  			},
   125  			serverSends: map[int32]string{
   126  				5001: "1234",
   127  				6000: "5678",
   128  			},
   129  		},
   130  	}
   131  
   132  	for testName, test := range tests {
   133  		t.Run(testName, func(t *testing.T) {
   134  			server := httptest.NewServer(fakePortForwardServer(t, testName, test.serverSends, test.clientSends))
   135  			defer server.Close()
   136  
   137  			transport, upgrader, err := spdy.RoundTripperFor(&restclient.Config{})
   138  			if err != nil {
   139  				t.Fatal(err)
   140  			}
   141  			url, _ := url.Parse(server.URL)
   142  			dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", url)
   143  
   144  			stopChan := make(chan struct{}, 1)
   145  			readyChan := make(chan struct{})
   146  
   147  			pf, err := New(dialer, test.ports, stopChan, readyChan, os.Stdout, os.Stderr)
   148  			if err != nil {
   149  				t.Fatalf("%s: unexpected error calling New: %v", testName, err)
   150  			}
   151  
   152  			doneChan := make(chan error)
   153  			go func() {
   154  				doneChan <- pf.ForwardPorts()
   155  			}()
   156  			<-pf.Ready
   157  
   158  			forwardedPorts, err := pf.GetPorts()
   159  			if err != nil {
   160  				t.Fatal(err)
   161  			}
   162  
   163  			remoteToLocalMap := map[int32]int32{}
   164  			for _, forwardedPort := range forwardedPorts {
   165  				remoteToLocalMap[int32(forwardedPort.Remote)] = int32(forwardedPort.Local)
   166  			}
   167  
   168  			clientSend := func(port int32, data string) error {
   169  				clientConn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", remoteToLocalMap[port]))
   170  				if err != nil {
   171  					return fmt.Errorf("%s: error dialing %d: %s", testName, port, err)
   172  
   173  				}
   174  				defer clientConn.Close()
   175  
   176  				n, err := clientConn.Write([]byte(data))
   177  				if err != nil && err != io.EOF {
   178  					return fmt.Errorf("%s: Error sending data '%s': %s", testName, data, err)
   179  				}
   180  				if n == 0 {
   181  					return fmt.Errorf("%s: unexpected write of 0 bytes", testName)
   182  				}
   183  				b := make([]byte, 4)
   184  				_, err = clientConn.Read(b)
   185  				if err != nil && err != io.EOF {
   186  					return fmt.Errorf("%s: Error reading data: %s", testName, err)
   187  				}
   188  				if !bytes.Equal([]byte(test.serverSends[port]), b) {
   189  					return fmt.Errorf("%s: expected to read '%s', got '%s'", testName, test.serverSends[port], b)
   190  				}
   191  				return nil
   192  			}
   193  			for port, data := range test.clientSends {
   194  				if err := clientSend(port, data); err != nil {
   195  					t.Error(err)
   196  				}
   197  			}
   198  			// tell r.ForwardPorts to stop
   199  			close(stopChan)
   200  
   201  			// wait for r.ForwardPorts to actually return
   202  			err = <-doneChan
   203  			if err != nil {
   204  				t.Errorf("%s: unexpected error: %s", testName, err)
   205  			}
   206  		})
   207  	}
   208  
   209  }
   210  
   211  func TestForwardPortsReturnsErrorWhenAllBindsFailed(t *testing.T) {
   212  	server := httptest.NewServer(fakePortForwardServer(t, "allBindsFailed", nil, nil))
   213  	defer server.Close()
   214  
   215  	transport, upgrader, err := spdy.RoundTripperFor(&restclient.Config{})
   216  	if err != nil {
   217  		t.Fatal(err)
   218  	}
   219  	url, _ := url.Parse(server.URL)
   220  	dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", url)
   221  
   222  	stopChan1 := make(chan struct{}, 1)
   223  	defer close(stopChan1)
   224  	readyChan1 := make(chan struct{})
   225  
   226  	pf1, err := New(dialer, []string{":5555"}, stopChan1, readyChan1, os.Stdout, os.Stderr)
   227  	if err != nil {
   228  		t.Fatalf("error creating pf1: %v", err)
   229  	}
   230  	go pf1.ForwardPorts()
   231  	<-pf1.Ready
   232  
   233  	forwardedPorts, err := pf1.GetPorts()
   234  	if err != nil {
   235  		t.Fatal(err)
   236  	}
   237  	if len(forwardedPorts) != 1 {
   238  		t.Fatalf("expected 1 forwarded port, got %#v", forwardedPorts)
   239  	}
   240  	duplicateSpec := fmt.Sprintf("%d:%d", forwardedPorts[0].Local, forwardedPorts[0].Remote)
   241  
   242  	stopChan2 := make(chan struct{}, 1)
   243  	readyChan2 := make(chan struct{})
   244  	pf2, err := New(dialer, []string{duplicateSpec}, stopChan2, readyChan2, os.Stdout, os.Stderr)
   245  	if err != nil {
   246  		t.Fatalf("error creating pf2: %v", err)
   247  	}
   248  	if err := pf2.ForwardPorts(); err == nil {
   249  		t.Fatal("expected non-nil error for pf2.ForwardPorts")
   250  	}
   251  }
   252  

View as plain text