...

Source file src/k8s.io/client-go/tools/portforward/portforward_test.go

Documentation: k8s.io/client-go/tools/portforward

     1  /*
     2  Copyright 2015 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package portforward
    18  
    19  import (
    20  	"bytes"
    21  	"fmt"
    22  	"net"
    23  	"net/http"
    24  	"os"
    25  	"reflect"
    26  	"sort"
    27  	"strings"
    28  	"testing"
    29  	"time"
    30  
    31  	"github.com/stretchr/testify/assert"
    32  
    33  	v1 "k8s.io/api/core/v1"
    34  	"k8s.io/apimachinery/pkg/util/httpstream"
    35  )
    36  
    37  type fakeDialer struct {
    38  	dialed             bool
    39  	conn               httpstream.Connection
    40  	err                error
    41  	negotiatedProtocol string
    42  }
    43  
    44  func (d *fakeDialer) Dial(protocols ...string) (httpstream.Connection, string, error) {
    45  	d.dialed = true
    46  	return d.conn, d.negotiatedProtocol, d.err
    47  }
    48  
    49  type fakeConnection struct {
    50  	closed      bool
    51  	closeChan   chan bool
    52  	dataStream  *fakeStream
    53  	errorStream *fakeStream
    54  	streamCount int
    55  }
    56  
    57  func newFakeConnection() *fakeConnection {
    58  	return &fakeConnection{
    59  		closeChan:   make(chan bool),
    60  		dataStream:  &fakeStream{},
    61  		errorStream: &fakeStream{},
    62  	}
    63  }
    64  
    65  func (c *fakeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) {
    66  	switch headers.Get(v1.StreamType) {
    67  	case v1.StreamTypeData:
    68  		c.streamCount++
    69  		return c.dataStream, nil
    70  	case v1.StreamTypeError:
    71  		c.streamCount++
    72  		return c.errorStream, nil
    73  	default:
    74  		return nil, fmt.Errorf("fakeStream creation not supported for stream type %s", headers.Get(v1.StreamType))
    75  	}
    76  }
    77  
    78  func (c *fakeConnection) Close() error {
    79  	if !c.closed {
    80  		c.closed = true
    81  		close(c.closeChan)
    82  	}
    83  	return nil
    84  }
    85  
    86  func (c *fakeConnection) CloseChan() <-chan bool {
    87  	return c.closeChan
    88  }
    89  
    90  func (c *fakeConnection) RemoveStreams(streams ...httpstream.Stream) {
    91  	for range streams {
    92  		c.streamCount--
    93  	}
    94  }
    95  
    96  func (c *fakeConnection) SetIdleTimeout(timeout time.Duration) {
    97  	// no-op
    98  }
    99  
   100  type fakeListener struct {
   101  	net.Listener
   102  	closeChan chan bool
   103  }
   104  
   105  func newFakeListener() fakeListener {
   106  	return fakeListener{
   107  		closeChan: make(chan bool),
   108  	}
   109  }
   110  
   111  func (l *fakeListener) Accept() (net.Conn, error) {
   112  	select {
   113  	case <-l.closeChan:
   114  		return nil, fmt.Errorf("listener closed")
   115  	}
   116  }
   117  
   118  func (l *fakeListener) Close() error {
   119  	close(l.closeChan)
   120  	return nil
   121  }
   122  
   123  func (l *fakeListener) Addr() net.Addr {
   124  	return fakeAddr{}
   125  }
   126  
   127  type fakeAddr struct{}
   128  
   129  func (fakeAddr) Network() string { return "fake" }
   130  func (fakeAddr) String() string  { return "fake" }
   131  
   132  type fakeStream struct {
   133  	headers   http.Header
   134  	readFunc  func(p []byte) (int, error)
   135  	writeFunc func(p []byte) (int, error)
   136  }
   137  
   138  func (s *fakeStream) Read(p []byte) (n int, err error)  { return s.readFunc(p) }
   139  func (s *fakeStream) Write(p []byte) (n int, err error) { return s.writeFunc(p) }
   140  func (*fakeStream) Close() error                        { return nil }
   141  func (*fakeStream) Reset() error                        { return nil }
   142  func (s *fakeStream) Headers() http.Header              { return s.headers }
   143  func (*fakeStream) Identifier() uint32                  { return 0 }
   144  
   145  type fakeConn struct {
   146  	sendBuffer    *bytes.Buffer
   147  	receiveBuffer *bytes.Buffer
   148  }
   149  
   150  func (f fakeConn) Read(p []byte) (int, error)       { return f.sendBuffer.Read(p) }
   151  func (f fakeConn) Write(p []byte) (int, error)      { return f.receiveBuffer.Write(p) }
   152  func (fakeConn) Close() error                       { return nil }
   153  func (fakeConn) LocalAddr() net.Addr                { return nil }
   154  func (fakeConn) RemoteAddr() net.Addr               { return nil }
   155  func (fakeConn) SetDeadline(t time.Time) error      { return nil }
   156  func (fakeConn) SetReadDeadline(t time.Time) error  { return nil }
   157  func (fakeConn) SetWriteDeadline(t time.Time) error { return nil }
   158  
   159  func TestParsePortsAndNew(t *testing.T) {
   160  	tests := []struct {
   161  		input                   []string
   162  		addresses               []string
   163  		expectedPorts           []ForwardedPort
   164  		expectedAddresses       []listenAddress
   165  		expectPortParseError    bool
   166  		expectAddressParseError bool
   167  		expectNewError          bool
   168  	}{
   169  		{input: []string{}, expectNewError: true},
   170  		{input: []string{"a"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true},
   171  		{input: []string{":a"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true},
   172  		{input: []string{"-1"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true},
   173  		{input: []string{"65536"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true},
   174  		{input: []string{"0"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true},
   175  		{input: []string{"0:0"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true},
   176  		{input: []string{"a:5000"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true},
   177  		{input: []string{"5000:a"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true},
   178  		{input: []string{"5000:5000"}, addresses: []string{"127.0.0.257"}, expectPortParseError: false, expectAddressParseError: true, expectNewError: true},
   179  		{input: []string{"5000:5000"}, addresses: []string{"::g"}, expectPortParseError: false, expectAddressParseError: true, expectNewError: true},
   180  		{input: []string{"5000:5000"}, addresses: []string{"domain.invalid"}, expectPortParseError: false, expectAddressParseError: true, expectNewError: true},
   181  		{
   182  			input:     []string{"5000:5000"},
   183  			addresses: []string{"localhost"},
   184  			expectedPorts: []ForwardedPort{
   185  				{5000, 5000},
   186  			},
   187  			expectedAddresses: []listenAddress{
   188  				{protocol: "tcp4", address: "127.0.0.1", failureMode: "all"},
   189  				{protocol: "tcp6", address: "::1", failureMode: "all"},
   190  			},
   191  		},
   192  		{
   193  			input:     []string{"5000:5000"},
   194  			addresses: []string{"localhost", "127.0.0.1"},
   195  			expectedPorts: []ForwardedPort{
   196  				{5000, 5000},
   197  			},
   198  			expectedAddresses: []listenAddress{
   199  				{protocol: "tcp4", address: "127.0.0.1", failureMode: "any"},
   200  				{protocol: "tcp6", address: "::1", failureMode: "all"},
   201  			},
   202  		},
   203  		{
   204  			input:     []string{"5000:5000"},
   205  			addresses: []string{"localhost", "::1"},
   206  			expectedPorts: []ForwardedPort{
   207  				{5000, 5000},
   208  			},
   209  			expectedAddresses: []listenAddress{
   210  				{protocol: "tcp4", address: "127.0.0.1", failureMode: "all"},
   211  				{protocol: "tcp6", address: "::1", failureMode: "any"},
   212  			},
   213  		},
   214  		{
   215  			input:     []string{"5000:5000"},
   216  			addresses: []string{"localhost", "127.0.0.1", "::1"},
   217  			expectedPorts: []ForwardedPort{
   218  				{5000, 5000},
   219  			},
   220  			expectedAddresses: []listenAddress{
   221  				{protocol: "tcp4", address: "127.0.0.1", failureMode: "any"},
   222  				{protocol: "tcp6", address: "::1", failureMode: "any"},
   223  			},
   224  		},
   225  		{
   226  			input:     []string{"5000:5000"},
   227  			addresses: []string{"localhost", "127.0.0.1", "10.10.10.1"},
   228  			expectedPorts: []ForwardedPort{
   229  				{5000, 5000},
   230  			},
   231  			expectedAddresses: []listenAddress{
   232  				{protocol: "tcp4", address: "127.0.0.1", failureMode: "any"},
   233  				{protocol: "tcp6", address: "::1", failureMode: "all"},
   234  				{protocol: "tcp4", address: "10.10.10.1", failureMode: "any"},
   235  			},
   236  		},
   237  		{
   238  			input:     []string{"5000:5000"},
   239  			addresses: []string{"127.0.0.1", "::1", "localhost"},
   240  			expectedPorts: []ForwardedPort{
   241  				{5000, 5000},
   242  			},
   243  			expectedAddresses: []listenAddress{
   244  				{protocol: "tcp4", address: "127.0.0.1", failureMode: "any"},
   245  				{protocol: "tcp6", address: "::1", failureMode: "any"},
   246  			},
   247  		},
   248  		{
   249  			input:     []string{"5000:5000"},
   250  			addresses: []string{"10.0.0.1", "127.0.0.1"},
   251  			expectedPorts: []ForwardedPort{
   252  				{5000, 5000},
   253  			},
   254  			expectedAddresses: []listenAddress{
   255  				{protocol: "tcp4", address: "10.0.0.1", failureMode: "any"},
   256  				{protocol: "tcp4", address: "127.0.0.1", failureMode: "any"},
   257  			},
   258  		},
   259  		{
   260  			input:     []string{"5000", "5000:5000", "8888:5000", "5000:8888", ":5000", "0:5000"},
   261  			addresses: []string{"127.0.0.1", "::1"},
   262  			expectedPorts: []ForwardedPort{
   263  				{5000, 5000},
   264  				{5000, 5000},
   265  				{8888, 5000},
   266  				{5000, 8888},
   267  				{0, 5000},
   268  				{0, 5000},
   269  			},
   270  			expectedAddresses: []listenAddress{
   271  				{protocol: "tcp4", address: "127.0.0.1", failureMode: "any"},
   272  				{protocol: "tcp6", address: "::1", failureMode: "any"},
   273  			},
   274  		},
   275  	}
   276  
   277  	for i, test := range tests {
   278  		parsedPorts, err := parsePorts(test.input)
   279  		haveError := err != nil
   280  		if e, a := test.expectPortParseError, haveError; e != a {
   281  			t.Fatalf("%d: parsePorts: error expected=%t, got %t: %s", i, e, a, err)
   282  		}
   283  
   284  		// default to localhost
   285  		if len(test.addresses) == 0 && len(test.expectedAddresses) == 0 {
   286  			test.addresses = []string{"localhost"}
   287  			test.expectedAddresses = []listenAddress{{protocol: "tcp4", address: "127.0.0.1"}, {protocol: "tcp6", address: "::1"}}
   288  		}
   289  		// assert address parser
   290  		parsedAddresses, err := parseAddresses(test.addresses)
   291  		haveError = err != nil
   292  		if e, a := test.expectAddressParseError, haveError; e != a {
   293  			t.Fatalf("%d: parseAddresses: error expected=%t, got %t: %s", i, e, a, err)
   294  		}
   295  
   296  		dialer := &fakeDialer{}
   297  		expectedStopChan := make(chan struct{})
   298  		readyChan := make(chan struct{})
   299  
   300  		var pf *PortForwarder
   301  		if len(test.addresses) > 0 {
   302  			pf, err = NewOnAddresses(dialer, test.addresses, test.input, expectedStopChan, readyChan, os.Stdout, os.Stderr)
   303  		} else {
   304  			pf, err = New(dialer, test.input, expectedStopChan, readyChan, os.Stdout, os.Stderr)
   305  		}
   306  		haveError = err != nil
   307  		if e, a := test.expectNewError, haveError; e != a {
   308  			t.Fatalf("%d: New: error expected=%t, got %t: %s", i, e, a, err)
   309  		}
   310  
   311  		if test.expectPortParseError || test.expectAddressParseError || test.expectNewError {
   312  			continue
   313  		}
   314  
   315  		sort.Slice(test.expectedAddresses, func(i, j int) bool { return test.expectedAddresses[i].address < test.expectedAddresses[j].address })
   316  		sort.Slice(parsedAddresses, func(i, j int) bool { return parsedAddresses[i].address < parsedAddresses[j].address })
   317  
   318  		if !reflect.DeepEqual(test.expectedAddresses, parsedAddresses) {
   319  			t.Fatalf("%d: expectedAddresses: %v, got: %v", i, test.expectedAddresses, parsedAddresses)
   320  		}
   321  
   322  		for pi, expectedPort := range test.expectedPorts {
   323  			if e, a := expectedPort.Local, parsedPorts[pi].Local; e != a {
   324  				t.Fatalf("%d: local expected: %d, got: %d", i, e, a)
   325  			}
   326  			if e, a := expectedPort.Remote, parsedPorts[pi].Remote; e != a {
   327  				t.Fatalf("%d: remote expected: %d, got: %d", i, e, a)
   328  			}
   329  		}
   330  
   331  		if dialer.dialed {
   332  			t.Fatalf("%d: expected not dialed", i)
   333  		}
   334  		if _, portErr := pf.GetPorts(); portErr == nil {
   335  			t.Fatalf("%d: GetPorts: error expected but got nil", i)
   336  		}
   337  
   338  		// mock-signal the Ready channel
   339  		close(readyChan)
   340  
   341  		if ports, portErr := pf.GetPorts(); portErr != nil {
   342  			t.Fatalf("%d: GetPorts: unable to retrieve ports: %s", i, portErr)
   343  		} else if !reflect.DeepEqual(test.expectedPorts, ports) {
   344  			t.Fatalf("%d: ports: expected %#v, got %#v", i, test.expectedPorts, ports)
   345  		}
   346  		if e, a := expectedStopChan, pf.stopChan; e != a {
   347  			t.Fatalf("%d: stopChan: expected %#v, got %#v", i, e, a)
   348  		}
   349  		if pf.Ready == nil {
   350  			t.Fatalf("%d: Ready should be non-nil", i)
   351  		}
   352  	}
   353  }
   354  
   355  type GetListenerTestCase struct {
   356  	Hostname                string
   357  	Protocol                string
   358  	ShouldRaiseError        bool
   359  	ExpectedListenerAddress string
   360  }
   361  
   362  func TestGetListener(t *testing.T) {
   363  	var pf PortForwarder
   364  	testCases := []GetListenerTestCase{
   365  		{
   366  			Hostname:                "localhost",
   367  			Protocol:                "tcp4",
   368  			ShouldRaiseError:        false,
   369  			ExpectedListenerAddress: "127.0.0.1",
   370  		},
   371  		{
   372  			Hostname:                "127.0.0.1",
   373  			Protocol:                "tcp4",
   374  			ShouldRaiseError:        false,
   375  			ExpectedListenerAddress: "127.0.0.1",
   376  		},
   377  		{
   378  			Hostname:                "::1",
   379  			Protocol:                "tcp6",
   380  			ShouldRaiseError:        false,
   381  			ExpectedListenerAddress: "::1",
   382  		},
   383  		{
   384  			Hostname:         "::1",
   385  			Protocol:         "tcp4",
   386  			ShouldRaiseError: true,
   387  		},
   388  		{
   389  			Hostname:         "127.0.0.1",
   390  			Protocol:         "tcp6",
   391  			ShouldRaiseError: true,
   392  		},
   393  	}
   394  
   395  	for i, testCase := range testCases {
   396  		forwardedPort := &ForwardedPort{Local: 0, Remote: 12345}
   397  		listener, err := pf.getListener(testCase.Protocol, testCase.Hostname, forwardedPort)
   398  		if err != nil && strings.Contains(err.Error(), "cannot assign requested address") {
   399  			t.Logf("Can't test #%d: %v", i, err)
   400  			continue
   401  		}
   402  		expectedListenerPort := fmt.Sprintf("%d", forwardedPort.Local)
   403  		errorRaised := err != nil
   404  
   405  		if testCase.ShouldRaiseError != errorRaised {
   406  			t.Errorf("Test case #%d failed: Data %v an error has been raised(%t) where it should not (or reciprocally): %v", i, testCase, testCase.ShouldRaiseError, err)
   407  			continue
   408  		}
   409  		if errorRaised {
   410  			continue
   411  		}
   412  
   413  		if listener == nil {
   414  			t.Errorf("Test case #%d did not raise an error but failed in initializing listener", i)
   415  			continue
   416  		}
   417  
   418  		host, port, _ := net.SplitHostPort(listener.Addr().String())
   419  		t.Logf("Asked a %s forward for: %s:0, got listener %s:%s, expected: %s", testCase.Protocol, testCase.Hostname, host, port, expectedListenerPort)
   420  		if host != testCase.ExpectedListenerAddress {
   421  			t.Errorf("Test case #%d failed: Listener does not listen on expected address: asked '%v' got '%v'", i, testCase.ExpectedListenerAddress, host)
   422  		}
   423  		if port != expectedListenerPort {
   424  			t.Errorf("Test case #%d failed: Listener does not listen on expected port: asked %v got %v", i, expectedListenerPort, port)
   425  
   426  		}
   427  		listener.Close()
   428  	}
   429  }
   430  
   431  func TestGetPortsReturnsDynamicallyAssignedLocalPort(t *testing.T) {
   432  	dialer := &fakeDialer{
   433  		conn:               newFakeConnection(),
   434  		negotiatedProtocol: PortForwardProtocolV1Name,
   435  	}
   436  
   437  	stopChan := make(chan struct{})
   438  	readyChan := make(chan struct{})
   439  	errChan := make(chan error)
   440  
   441  	defer func() {
   442  		close(stopChan)
   443  
   444  		forwardErr := <-errChan
   445  		if forwardErr != nil {
   446  			t.Fatalf("ForwardPorts returned error: %s", forwardErr)
   447  		}
   448  	}()
   449  
   450  	pf, err := New(dialer, []string{":5000"}, stopChan, readyChan, os.Stdout, os.Stderr)
   451  
   452  	if err != nil {
   453  		t.Fatalf("error while calling New: %s", err)
   454  	}
   455  
   456  	go func() {
   457  		errChan <- pf.ForwardPorts()
   458  		close(errChan)
   459  	}()
   460  
   461  	<-pf.Ready
   462  
   463  	ports, err := pf.GetPorts()
   464  	if err != nil {
   465  		t.Fatalf("Failed to get ports. error: %v", err)
   466  	}
   467  
   468  	if len(ports) != 1 {
   469  		t.Fatalf("expected 1 port, got %d", len(ports))
   470  	}
   471  
   472  	port := ports[0]
   473  	if port.Local == 0 {
   474  		t.Fatalf("local port is 0, expected != 0")
   475  	}
   476  }
   477  
   478  func TestHandleConnection(t *testing.T) {
   479  	out := bytes.NewBufferString("")
   480  
   481  	pf, err := New(&fakeDialer{}, []string{":2222"}, nil, nil, out, nil)
   482  	if err != nil {
   483  		t.Fatalf("error while calling New: %s", err)
   484  	}
   485  
   486  	// Setup fake local connection
   487  	localConnection := &fakeConn{
   488  		sendBuffer:    bytes.NewBufferString("test data from local"),
   489  		receiveBuffer: bytes.NewBufferString(""),
   490  	}
   491  
   492  	// Setup fake remote connection to send data on the data stream after it receives data from the local connection
   493  	remoteDataToSend := bytes.NewBufferString("test data from remote")
   494  	remoteDataReceived := bytes.NewBufferString("")
   495  	remoteErrorToSend := bytes.NewBufferString("")
   496  	blockRemoteSend := make(chan struct{})
   497  	remoteConnection := newFakeConnection()
   498  	remoteConnection.dataStream.readFunc = func(p []byte) (int, error) {
   499  		<-blockRemoteSend // Wait for the expected data to be received before responding
   500  		return remoteDataToSend.Read(p)
   501  	}
   502  	remoteConnection.dataStream.writeFunc = func(p []byte) (int, error) {
   503  		n, err := remoteDataReceived.Write(p)
   504  		if remoteDataReceived.String() == "test data from local" {
   505  			close(blockRemoteSend)
   506  		}
   507  		return n, err
   508  	}
   509  	remoteConnection.errorStream.readFunc = remoteErrorToSend.Read
   510  	pf.streamConn = remoteConnection
   511  
   512  	// Test handleConnection
   513  	pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222})
   514  	assert.Equal(t, 0, remoteConnection.streamCount, "stream count should be zero")
   515  	assert.Equal(t, "test data from local", remoteDataReceived.String())
   516  	assert.Equal(t, "test data from remote", localConnection.receiveBuffer.String())
   517  	assert.Equal(t, "Handling connection for 1111\n", out.String())
   518  }
   519  
   520  func TestHandleConnectionSendsRemoteError(t *testing.T) {
   521  	out := bytes.NewBufferString("")
   522  	errOut := bytes.NewBufferString("")
   523  
   524  	pf, err := New(&fakeDialer{}, []string{":2222"}, nil, nil, out, errOut)
   525  	if err != nil {
   526  		t.Fatalf("error while calling New: %s", err)
   527  	}
   528  
   529  	// Setup fake local connection
   530  	localConnection := &fakeConn{
   531  		sendBuffer:    bytes.NewBufferString(""),
   532  		receiveBuffer: bytes.NewBufferString(""),
   533  	}
   534  
   535  	// Setup fake remote connection to return an error message on the error stream
   536  	remoteDataToSend := bytes.NewBufferString("")
   537  	remoteDataReceived := bytes.NewBufferString("")
   538  	remoteErrorToSend := bytes.NewBufferString("error")
   539  	remoteConnection := newFakeConnection()
   540  	remoteConnection.dataStream.readFunc = remoteDataToSend.Read
   541  	remoteConnection.dataStream.writeFunc = remoteDataReceived.Write
   542  	remoteConnection.errorStream.readFunc = remoteErrorToSend.Read
   543  	pf.streamConn = remoteConnection
   544  
   545  	// Test handleConnection, using go-routine because it needs to be able to write to unbuffered pf.errorChan
   546  	pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222})
   547  
   548  	assert.Equal(t, 0, remoteConnection.streamCount, "stream count should be zero")
   549  	assert.Equal(t, "", remoteDataReceived.String())
   550  	assert.Equal(t, "", localConnection.receiveBuffer.String())
   551  	assert.Equal(t, "Handling connection for 1111\n", out.String())
   552  }
   553  
   554  func TestWaitForConnectionExitsOnStreamConnClosed(t *testing.T) {
   555  	out := bytes.NewBufferString("")
   556  	errOut := bytes.NewBufferString("")
   557  
   558  	pf, err := New(&fakeDialer{}, []string{":2222"}, nil, nil, out, errOut)
   559  	if err != nil {
   560  		t.Fatalf("error while calling New: %s", err)
   561  	}
   562  
   563  	listener := newFakeListener()
   564  
   565  	pf.streamConn = newFakeConnection()
   566  	pf.streamConn.Close()
   567  
   568  	port := ForwardedPort{}
   569  	pf.waitForConnection(&listener, port)
   570  }
   571  
   572  func TestForwardPortsReturnsErrorWhenConnectionIsLost(t *testing.T) {
   573  	dialer := &fakeDialer{
   574  		conn:               newFakeConnection(),
   575  		negotiatedProtocol: PortForwardProtocolV1Name,
   576  	}
   577  
   578  	stopChan := make(chan struct{})
   579  	readyChan := make(chan struct{})
   580  	errChan := make(chan error)
   581  
   582  	pf, err := New(dialer, []string{":5000"}, stopChan, readyChan, os.Stdout, os.Stderr)
   583  	if err != nil {
   584  		t.Fatalf("failed to create new PortForwarder: %s", err)
   585  	}
   586  
   587  	go func() {
   588  		errChan <- pf.ForwardPorts()
   589  	}()
   590  
   591  	<-pf.Ready
   592  
   593  	// Simulate lost pod connection by closing streamConn, which should result in pf.ForwardPorts() returning an error.
   594  	pf.streamConn.Close()
   595  
   596  	err = <-errChan
   597  	if err == nil {
   598  		t.Fatalf("unexpected non-error from pf.ForwardPorts()")
   599  	} else if err != ErrLostConnectionToPod {
   600  		t.Fatalf("unexpected error from pf.ForwardPorts(): %s", err)
   601  	}
   602  }
   603  
   604  func TestForwardPortsReturnsNilWhenStopChanIsClosed(t *testing.T) {
   605  	dialer := &fakeDialer{
   606  		conn:               newFakeConnection(),
   607  		negotiatedProtocol: PortForwardProtocolV1Name,
   608  	}
   609  
   610  	stopChan := make(chan struct{})
   611  	readyChan := make(chan struct{})
   612  	errChan := make(chan error)
   613  
   614  	pf, err := New(dialer, []string{":5000"}, stopChan, readyChan, os.Stdout, os.Stderr)
   615  	if err != nil {
   616  		t.Fatalf("failed to create new PortForwarder: %s", err)
   617  	}
   618  
   619  	go func() {
   620  		errChan <- pf.ForwardPorts()
   621  	}()
   622  
   623  	<-pf.Ready
   624  
   625  	// Closing (or sending to) stopChan indicates a stop request by the caller, which should result in pf.ForwardPorts()
   626  	// returning nil.
   627  	close(stopChan)
   628  
   629  	err = <-errChan
   630  	if err != nil {
   631  		t.Fatalf("unexpected error from pf.ForwardPorts(): %s", err)
   632  	}
   633  }
   634  

View as plain text