1
16
17 package portforward
18
19 import (
20 "bytes"
21 "fmt"
22 "net"
23 "net/http"
24 "os"
25 "reflect"
26 "sort"
27 "strings"
28 "testing"
29 "time"
30
31 "github.com/stretchr/testify/assert"
32
33 v1 "k8s.io/api/core/v1"
34 "k8s.io/apimachinery/pkg/util/httpstream"
35 )
36
37 type fakeDialer struct {
38 dialed bool
39 conn httpstream.Connection
40 err error
41 negotiatedProtocol string
42 }
43
44 func (d *fakeDialer) Dial(protocols ...string) (httpstream.Connection, string, error) {
45 d.dialed = true
46 return d.conn, d.negotiatedProtocol, d.err
47 }
48
49 type fakeConnection struct {
50 closed bool
51 closeChan chan bool
52 dataStream *fakeStream
53 errorStream *fakeStream
54 streamCount int
55 }
56
57 func newFakeConnection() *fakeConnection {
58 return &fakeConnection{
59 closeChan: make(chan bool),
60 dataStream: &fakeStream{},
61 errorStream: &fakeStream{},
62 }
63 }
64
65 func (c *fakeConnection) CreateStream(headers http.Header) (httpstream.Stream, error) {
66 switch headers.Get(v1.StreamType) {
67 case v1.StreamTypeData:
68 c.streamCount++
69 return c.dataStream, nil
70 case v1.StreamTypeError:
71 c.streamCount++
72 return c.errorStream, nil
73 default:
74 return nil, fmt.Errorf("fakeStream creation not supported for stream type %s", headers.Get(v1.StreamType))
75 }
76 }
77
78 func (c *fakeConnection) Close() error {
79 if !c.closed {
80 c.closed = true
81 close(c.closeChan)
82 }
83 return nil
84 }
85
86 func (c *fakeConnection) CloseChan() <-chan bool {
87 return c.closeChan
88 }
89
90 func (c *fakeConnection) RemoveStreams(streams ...httpstream.Stream) {
91 for range streams {
92 c.streamCount--
93 }
94 }
95
96 func (c *fakeConnection) SetIdleTimeout(timeout time.Duration) {
97
98 }
99
// fakeListener is a net.Listener test double that never yields a connection:
// Accept blocks until the listener is closed and then reports an error.
// The embedded net.Listener is left nil; only the methods defined below may
// be called.
type fakeListener struct {
	net.Listener
	closeChan chan bool
}

// newFakeListener returns an open fakeListener.
func newFakeListener() fakeListener {
	return fakeListener{
		closeChan: make(chan bool),
	}
}

// Accept blocks until Close has been called and then returns a
// "listener closed" error. It never returns a connection.
func (l *fakeListener) Accept() (net.Conn, error) {
	<-l.closeChan
	return nil, fmt.Errorf("listener closed")
}

// Close unblocks pending and future Accept calls. Calling Close more than
// once is safe: the channel is only closed the first time, mirroring the
// guard in fakeConnection.Close. It is not safe for concurrent Close calls.
func (l *fakeListener) Close() error {
	select {
	case <-l.closeChan:
		// Already closed; closing again would panic.
	default:
		close(l.closeChan)
	}
	return nil
}

// Addr returns a placeholder address.
func (l *fakeListener) Addr() net.Addr {
	return fakeAddr{}
}

// fakeAddr is a placeholder net.Addr whose network and string form are both
// the literal "fake".
type fakeAddr struct{}

// Network returns the placeholder network name.
func (fakeAddr) Network() string { return "fake" }

// String returns the placeholder address.
func (fakeAddr) String() string { return "fake" }
131
// fakeStream is a stream test double whose Read and Write behaviour is
// supplied by the test through readFunc and writeFunc. Both functions must be
// set before the corresponding method is called.
type fakeStream struct {
	headers   http.Header
	readFunc  func(p []byte) (int, error)
	writeFunc func(p []byte) (int, error)
}

// Read forwards to readFunc.
func (s *fakeStream) Read(p []byte) (n int, err error) {
	return s.readFunc(p)
}

// Write forwards to writeFunc.
func (s *fakeStream) Write(p []byte) (n int, err error) {
	return s.writeFunc(p)
}

// Close does nothing and always succeeds.
func (*fakeStream) Close() error {
	return nil
}

// Reset does nothing and always succeeds.
func (*fakeStream) Reset() error {
	return nil
}

// Headers returns the headers stored on the stream (nil unless set).
func (s *fakeStream) Headers() http.Header {
	return s.headers
}

// Identifier always reports stream id 0.
func (*fakeStream) Identifier() uint32 {
	return 0
}
144
// fakeConn is a buffer-backed local connection test double: reads drain
// sendBuffer and writes accumulate in receiveBuffer. Addresses are nil and
// deadlines are ignored.
type fakeConn struct {
	sendBuffer    *bytes.Buffer
	receiveBuffer *bytes.Buffer
}

// Read consumes bytes from sendBuffer.
func (c fakeConn) Read(p []byte) (int, error) {
	return c.sendBuffer.Read(p)
}

// Write appends p to receiveBuffer.
func (c fakeConn) Write(p []byte) (int, error) {
	return c.receiveBuffer.Write(p)
}

// Close does nothing and always succeeds.
func (fakeConn) Close() error {
	return nil
}

// LocalAddr always returns nil.
func (fakeConn) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr always returns nil.
func (fakeConn) RemoteAddr() net.Addr {
	return nil
}

// SetDeadline ignores the deadline and always succeeds.
func (fakeConn) SetDeadline(t time.Time) error {
	return nil
}

// SetReadDeadline ignores the deadline and always succeeds.
func (fakeConn) SetReadDeadline(t time.Time) error {
	return nil
}

// SetWriteDeadline ignores the deadline and always succeeds.
func (fakeConn) SetWriteDeadline(t time.Time) error {
	return nil
}
158
159 func TestParsePortsAndNew(t *testing.T) {
160 tests := []struct {
161 input []string
162 addresses []string
163 expectedPorts []ForwardedPort
164 expectedAddresses []listenAddress
165 expectPortParseError bool
166 expectAddressParseError bool
167 expectNewError bool
168 }{
169 {input: []string{}, expectNewError: true},
170 {input: []string{"a"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true},
171 {input: []string{":a"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true},
172 {input: []string{"-1"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true},
173 {input: []string{"65536"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true},
174 {input: []string{"0"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true},
175 {input: []string{"0:0"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true},
176 {input: []string{"a:5000"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true},
177 {input: []string{"5000:a"}, expectPortParseError: true, expectAddressParseError: false, expectNewError: true},
178 {input: []string{"5000:5000"}, addresses: []string{"127.0.0.257"}, expectPortParseError: false, expectAddressParseError: true, expectNewError: true},
179 {input: []string{"5000:5000"}, addresses: []string{"::g"}, expectPortParseError: false, expectAddressParseError: true, expectNewError: true},
180 {input: []string{"5000:5000"}, addresses: []string{"domain.invalid"}, expectPortParseError: false, expectAddressParseError: true, expectNewError: true},
181 {
182 input: []string{"5000:5000"},
183 addresses: []string{"localhost"},
184 expectedPorts: []ForwardedPort{
185 {5000, 5000},
186 },
187 expectedAddresses: []listenAddress{
188 {protocol: "tcp4", address: "127.0.0.1", failureMode: "all"},
189 {protocol: "tcp6", address: "::1", failureMode: "all"},
190 },
191 },
192 {
193 input: []string{"5000:5000"},
194 addresses: []string{"localhost", "127.0.0.1"},
195 expectedPorts: []ForwardedPort{
196 {5000, 5000},
197 },
198 expectedAddresses: []listenAddress{
199 {protocol: "tcp4", address: "127.0.0.1", failureMode: "any"},
200 {protocol: "tcp6", address: "::1", failureMode: "all"},
201 },
202 },
203 {
204 input: []string{"5000:5000"},
205 addresses: []string{"localhost", "::1"},
206 expectedPorts: []ForwardedPort{
207 {5000, 5000},
208 },
209 expectedAddresses: []listenAddress{
210 {protocol: "tcp4", address: "127.0.0.1", failureMode: "all"},
211 {protocol: "tcp6", address: "::1", failureMode: "any"},
212 },
213 },
214 {
215 input: []string{"5000:5000"},
216 addresses: []string{"localhost", "127.0.0.1", "::1"},
217 expectedPorts: []ForwardedPort{
218 {5000, 5000},
219 },
220 expectedAddresses: []listenAddress{
221 {protocol: "tcp4", address: "127.0.0.1", failureMode: "any"},
222 {protocol: "tcp6", address: "::1", failureMode: "any"},
223 },
224 },
225 {
226 input: []string{"5000:5000"},
227 addresses: []string{"localhost", "127.0.0.1", "10.10.10.1"},
228 expectedPorts: []ForwardedPort{
229 {5000, 5000},
230 },
231 expectedAddresses: []listenAddress{
232 {protocol: "tcp4", address: "127.0.0.1", failureMode: "any"},
233 {protocol: "tcp6", address: "::1", failureMode: "all"},
234 {protocol: "tcp4", address: "10.10.10.1", failureMode: "any"},
235 },
236 },
237 {
238 input: []string{"5000:5000"},
239 addresses: []string{"127.0.0.1", "::1", "localhost"},
240 expectedPorts: []ForwardedPort{
241 {5000, 5000},
242 },
243 expectedAddresses: []listenAddress{
244 {protocol: "tcp4", address: "127.0.0.1", failureMode: "any"},
245 {protocol: "tcp6", address: "::1", failureMode: "any"},
246 },
247 },
248 {
249 input: []string{"5000:5000"},
250 addresses: []string{"10.0.0.1", "127.0.0.1"},
251 expectedPorts: []ForwardedPort{
252 {5000, 5000},
253 },
254 expectedAddresses: []listenAddress{
255 {protocol: "tcp4", address: "10.0.0.1", failureMode: "any"},
256 {protocol: "tcp4", address: "127.0.0.1", failureMode: "any"},
257 },
258 },
259 {
260 input: []string{"5000", "5000:5000", "8888:5000", "5000:8888", ":5000", "0:5000"},
261 addresses: []string{"127.0.0.1", "::1"},
262 expectedPorts: []ForwardedPort{
263 {5000, 5000},
264 {5000, 5000},
265 {8888, 5000},
266 {5000, 8888},
267 {0, 5000},
268 {0, 5000},
269 },
270 expectedAddresses: []listenAddress{
271 {protocol: "tcp4", address: "127.0.0.1", failureMode: "any"},
272 {protocol: "tcp6", address: "::1", failureMode: "any"},
273 },
274 },
275 }
276
277 for i, test := range tests {
278 parsedPorts, err := parsePorts(test.input)
279 haveError := err != nil
280 if e, a := test.expectPortParseError, haveError; e != a {
281 t.Fatalf("%d: parsePorts: error expected=%t, got %t: %s", i, e, a, err)
282 }
283
284
285 if len(test.addresses) == 0 && len(test.expectedAddresses) == 0 {
286 test.addresses = []string{"localhost"}
287 test.expectedAddresses = []listenAddress{{protocol: "tcp4", address: "127.0.0.1"}, {protocol: "tcp6", address: "::1"}}
288 }
289
290 parsedAddresses, err := parseAddresses(test.addresses)
291 haveError = err != nil
292 if e, a := test.expectAddressParseError, haveError; e != a {
293 t.Fatalf("%d: parseAddresses: error expected=%t, got %t: %s", i, e, a, err)
294 }
295
296 dialer := &fakeDialer{}
297 expectedStopChan := make(chan struct{})
298 readyChan := make(chan struct{})
299
300 var pf *PortForwarder
301 if len(test.addresses) > 0 {
302 pf, err = NewOnAddresses(dialer, test.addresses, test.input, expectedStopChan, readyChan, os.Stdout, os.Stderr)
303 } else {
304 pf, err = New(dialer, test.input, expectedStopChan, readyChan, os.Stdout, os.Stderr)
305 }
306 haveError = err != nil
307 if e, a := test.expectNewError, haveError; e != a {
308 t.Fatalf("%d: New: error expected=%t, got %t: %s", i, e, a, err)
309 }
310
311 if test.expectPortParseError || test.expectAddressParseError || test.expectNewError {
312 continue
313 }
314
315 sort.Slice(test.expectedAddresses, func(i, j int) bool { return test.expectedAddresses[i].address < test.expectedAddresses[j].address })
316 sort.Slice(parsedAddresses, func(i, j int) bool { return parsedAddresses[i].address < parsedAddresses[j].address })
317
318 if !reflect.DeepEqual(test.expectedAddresses, parsedAddresses) {
319 t.Fatalf("%d: expectedAddresses: %v, got: %v", i, test.expectedAddresses, parsedAddresses)
320 }
321
322 for pi, expectedPort := range test.expectedPorts {
323 if e, a := expectedPort.Local, parsedPorts[pi].Local; e != a {
324 t.Fatalf("%d: local expected: %d, got: %d", i, e, a)
325 }
326 if e, a := expectedPort.Remote, parsedPorts[pi].Remote; e != a {
327 t.Fatalf("%d: remote expected: %d, got: %d", i, e, a)
328 }
329 }
330
331 if dialer.dialed {
332 t.Fatalf("%d: expected not dialed", i)
333 }
334 if _, portErr := pf.GetPorts(); portErr == nil {
335 t.Fatalf("%d: GetPorts: error expected but got nil", i)
336 }
337
338
339 close(readyChan)
340
341 if ports, portErr := pf.GetPorts(); portErr != nil {
342 t.Fatalf("%d: GetPorts: unable to retrieve ports: %s", i, portErr)
343 } else if !reflect.DeepEqual(test.expectedPorts, ports) {
344 t.Fatalf("%d: ports: expected %#v, got %#v", i, test.expectedPorts, ports)
345 }
346 if e, a := expectedStopChan, pf.stopChan; e != a {
347 t.Fatalf("%d: stopChan: expected %#v, got %#v", i, e, a)
348 }
349 if pf.Ready == nil {
350 t.Fatalf("%d: Ready should be non-nil", i)
351 }
352 }
353 }
354
355 type GetListenerTestCase struct {
356 Hostname string
357 Protocol string
358 ShouldRaiseError bool
359 ExpectedListenerAddress string
360 }
361
362 func TestGetListener(t *testing.T) {
363 var pf PortForwarder
364 testCases := []GetListenerTestCase{
365 {
366 Hostname: "localhost",
367 Protocol: "tcp4",
368 ShouldRaiseError: false,
369 ExpectedListenerAddress: "127.0.0.1",
370 },
371 {
372 Hostname: "127.0.0.1",
373 Protocol: "tcp4",
374 ShouldRaiseError: false,
375 ExpectedListenerAddress: "127.0.0.1",
376 },
377 {
378 Hostname: "::1",
379 Protocol: "tcp6",
380 ShouldRaiseError: false,
381 ExpectedListenerAddress: "::1",
382 },
383 {
384 Hostname: "::1",
385 Protocol: "tcp4",
386 ShouldRaiseError: true,
387 },
388 {
389 Hostname: "127.0.0.1",
390 Protocol: "tcp6",
391 ShouldRaiseError: true,
392 },
393 }
394
395 for i, testCase := range testCases {
396 forwardedPort := &ForwardedPort{Local: 0, Remote: 12345}
397 listener, err := pf.getListener(testCase.Protocol, testCase.Hostname, forwardedPort)
398 if err != nil && strings.Contains(err.Error(), "cannot assign requested address") {
399 t.Logf("Can't test #%d: %v", i, err)
400 continue
401 }
402 expectedListenerPort := fmt.Sprintf("%d", forwardedPort.Local)
403 errorRaised := err != nil
404
405 if testCase.ShouldRaiseError != errorRaised {
406 t.Errorf("Test case #%d failed: Data %v an error has been raised(%t) where it should not (or reciprocally): %v", i, testCase, testCase.ShouldRaiseError, err)
407 continue
408 }
409 if errorRaised {
410 continue
411 }
412
413 if listener == nil {
414 t.Errorf("Test case #%d did not raise an error but failed in initializing listener", i)
415 continue
416 }
417
418 host, port, _ := net.SplitHostPort(listener.Addr().String())
419 t.Logf("Asked a %s forward for: %s:0, got listener %s:%s, expected: %s", testCase.Protocol, testCase.Hostname, host, port, expectedListenerPort)
420 if host != testCase.ExpectedListenerAddress {
421 t.Errorf("Test case #%d failed: Listener does not listen on expected address: asked '%v' got '%v'", i, testCase.ExpectedListenerAddress, host)
422 }
423 if port != expectedListenerPort {
424 t.Errorf("Test case #%d failed: Listener does not listen on expected port: asked %v got %v", i, expectedListenerPort, port)
425
426 }
427 listener.Close()
428 }
429 }
430
431 func TestGetPortsReturnsDynamicallyAssignedLocalPort(t *testing.T) {
432 dialer := &fakeDialer{
433 conn: newFakeConnection(),
434 negotiatedProtocol: PortForwardProtocolV1Name,
435 }
436
437 stopChan := make(chan struct{})
438 readyChan := make(chan struct{})
439 errChan := make(chan error)
440
441 defer func() {
442 close(stopChan)
443
444 forwardErr := <-errChan
445 if forwardErr != nil {
446 t.Fatalf("ForwardPorts returned error: %s", forwardErr)
447 }
448 }()
449
450 pf, err := New(dialer, []string{":5000"}, stopChan, readyChan, os.Stdout, os.Stderr)
451
452 if err != nil {
453 t.Fatalf("error while calling New: %s", err)
454 }
455
456 go func() {
457 errChan <- pf.ForwardPorts()
458 close(errChan)
459 }()
460
461 <-pf.Ready
462
463 ports, err := pf.GetPorts()
464 if err != nil {
465 t.Fatalf("Failed to get ports. error: %v", err)
466 }
467
468 if len(ports) != 1 {
469 t.Fatalf("expected 1 port, got %d", len(ports))
470 }
471
472 port := ports[0]
473 if port.Local == 0 {
474 t.Fatalf("local port is 0, expected != 0")
475 }
476 }
477
478 func TestHandleConnection(t *testing.T) {
479 out := bytes.NewBufferString("")
480
481 pf, err := New(&fakeDialer{}, []string{":2222"}, nil, nil, out, nil)
482 if err != nil {
483 t.Fatalf("error while calling New: %s", err)
484 }
485
486
487 localConnection := &fakeConn{
488 sendBuffer: bytes.NewBufferString("test data from local"),
489 receiveBuffer: bytes.NewBufferString(""),
490 }
491
492
493 remoteDataToSend := bytes.NewBufferString("test data from remote")
494 remoteDataReceived := bytes.NewBufferString("")
495 remoteErrorToSend := bytes.NewBufferString("")
496 blockRemoteSend := make(chan struct{})
497 remoteConnection := newFakeConnection()
498 remoteConnection.dataStream.readFunc = func(p []byte) (int, error) {
499 <-blockRemoteSend
500 return remoteDataToSend.Read(p)
501 }
502 remoteConnection.dataStream.writeFunc = func(p []byte) (int, error) {
503 n, err := remoteDataReceived.Write(p)
504 if remoteDataReceived.String() == "test data from local" {
505 close(blockRemoteSend)
506 }
507 return n, err
508 }
509 remoteConnection.errorStream.readFunc = remoteErrorToSend.Read
510 pf.streamConn = remoteConnection
511
512
513 pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222})
514 assert.Equal(t, 0, remoteConnection.streamCount, "stream count should be zero")
515 assert.Equal(t, "test data from local", remoteDataReceived.String())
516 assert.Equal(t, "test data from remote", localConnection.receiveBuffer.String())
517 assert.Equal(t, "Handling connection for 1111\n", out.String())
518 }
519
520 func TestHandleConnectionSendsRemoteError(t *testing.T) {
521 out := bytes.NewBufferString("")
522 errOut := bytes.NewBufferString("")
523
524 pf, err := New(&fakeDialer{}, []string{":2222"}, nil, nil, out, errOut)
525 if err != nil {
526 t.Fatalf("error while calling New: %s", err)
527 }
528
529
530 localConnection := &fakeConn{
531 sendBuffer: bytes.NewBufferString(""),
532 receiveBuffer: bytes.NewBufferString(""),
533 }
534
535
536 remoteDataToSend := bytes.NewBufferString("")
537 remoteDataReceived := bytes.NewBufferString("")
538 remoteErrorToSend := bytes.NewBufferString("error")
539 remoteConnection := newFakeConnection()
540 remoteConnection.dataStream.readFunc = remoteDataToSend.Read
541 remoteConnection.dataStream.writeFunc = remoteDataReceived.Write
542 remoteConnection.errorStream.readFunc = remoteErrorToSend.Read
543 pf.streamConn = remoteConnection
544
545
546 pf.handleConnection(localConnection, ForwardedPort{Local: 1111, Remote: 2222})
547
548 assert.Equal(t, 0, remoteConnection.streamCount, "stream count should be zero")
549 assert.Equal(t, "", remoteDataReceived.String())
550 assert.Equal(t, "", localConnection.receiveBuffer.String())
551 assert.Equal(t, "Handling connection for 1111\n", out.String())
552 }
553
554 func TestWaitForConnectionExitsOnStreamConnClosed(t *testing.T) {
555 out := bytes.NewBufferString("")
556 errOut := bytes.NewBufferString("")
557
558 pf, err := New(&fakeDialer{}, []string{":2222"}, nil, nil, out, errOut)
559 if err != nil {
560 t.Fatalf("error while calling New: %s", err)
561 }
562
563 listener := newFakeListener()
564
565 pf.streamConn = newFakeConnection()
566 pf.streamConn.Close()
567
568 port := ForwardedPort{}
569 pf.waitForConnection(&listener, port)
570 }
571
572 func TestForwardPortsReturnsErrorWhenConnectionIsLost(t *testing.T) {
573 dialer := &fakeDialer{
574 conn: newFakeConnection(),
575 negotiatedProtocol: PortForwardProtocolV1Name,
576 }
577
578 stopChan := make(chan struct{})
579 readyChan := make(chan struct{})
580 errChan := make(chan error)
581
582 pf, err := New(dialer, []string{":5000"}, stopChan, readyChan, os.Stdout, os.Stderr)
583 if err != nil {
584 t.Fatalf("failed to create new PortForwarder: %s", err)
585 }
586
587 go func() {
588 errChan <- pf.ForwardPorts()
589 }()
590
591 <-pf.Ready
592
593
594 pf.streamConn.Close()
595
596 err = <-errChan
597 if err == nil {
598 t.Fatalf("unexpected non-error from pf.ForwardPorts()")
599 } else if err != ErrLostConnectionToPod {
600 t.Fatalf("unexpected error from pf.ForwardPorts(): %s", err)
601 }
602 }
603
604 func TestForwardPortsReturnsNilWhenStopChanIsClosed(t *testing.T) {
605 dialer := &fakeDialer{
606 conn: newFakeConnection(),
607 negotiatedProtocol: PortForwardProtocolV1Name,
608 }
609
610 stopChan := make(chan struct{})
611 readyChan := make(chan struct{})
612 errChan := make(chan error)
613
614 pf, err := New(dialer, []string{":5000"}, stopChan, readyChan, os.Stdout, os.Stderr)
615 if err != nil {
616 t.Fatalf("failed to create new PortForwarder: %s", err)
617 }
618
619 go func() {
620 errChan <- pf.ForwardPorts()
621 }()
622
623 <-pf.Ready
624
625
626
627 close(stopChan)
628
629 err = <-errChan
630 if err != nil {
631 t.Fatalf("unexpected error from pf.ForwardPorts(): %s", err)
632 }
633 }
634
View as plain text