1
16
17 package tests
18
19 import (
20 "bytes"
21 "context"
22 "fmt"
23 "io"
24 "net"
25 "net/http"
26 "net/http/httptest"
27 "net/url"
28 "os"
29 "strings"
30 "sync"
31 "testing"
32 "time"
33
34 "k8s.io/apimachinery/pkg/types"
35 restclient "k8s.io/client-go/rest"
36 . "k8s.io/client-go/tools/portforward"
37 "k8s.io/client-go/transport/spdy"
38 "k8s.io/kubelet/pkg/cri/streaming/portforward"
39 )
40
41
42
// fakePortForwarder is a test double for portforward.PortForwarder. For each
// remote port it records what the client sent and replies with canned data.
type fakePortForwarder struct {
	// expected maps a remote port to the payload the client should send on it.
	expected map[int32]string
	// send maps a remote port to the payload written back to the client.
	send map[int32]string

	// lock guards received, which PortForward updates.
	lock sync.Mutex
	// received maps a remote port to the payload actually read from the client.
	received map[int32]string
}
52
53 var _ portforward.PortForwarder = &fakePortForwarder{}
54
55 func (pf *fakePortForwarder) PortForward(_ context.Context, name string, uid types.UID, port int32, stream io.ReadWriteCloser) error {
56 defer stream.Close()
57
58
59 received := make([]byte, len(pf.expected[port]))
60 n, err := stream.Read(received)
61 if err != nil {
62 return fmt.Errorf("error reading from client for port %d: %v", port, err)
63 }
64 if n != len(pf.expected[port]) {
65 return fmt.Errorf("unexpected length read from client for port %d: got %d, expected %d. data=%q", port, n, len(pf.expected[port]), string(received))
66 }
67
68
69 pf.lock.Lock()
70 pf.received[port] = string(received)
71 pf.lock.Unlock()
72
73
74 io.Copy(stream, strings.NewReader(pf.send[port]))
75
76 return nil
77 }
78
79
80
81 func fakePortForwardServer(t *testing.T, testName string, serverSends, expectedFromClient map[int32]string) http.HandlerFunc {
82 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
83 pf := &fakePortForwarder{
84 expected: expectedFromClient,
85 received: make(map[int32]string),
86 send: serverSends,
87 }
88 portforward.ServePortForward(w, req, pf, "pod", "uid", nil, 0, 10*time.Second, portforward.SupportedProtocols)
89
90 for port, expected := range expectedFromClient {
91 actual, ok := pf.received[port]
92 if !ok {
93 t.Errorf("%s: server didn't receive any data for port %d", testName, port)
94 continue
95 }
96
97 if expected != actual {
98 t.Errorf("%s: server expected to receive %q, got %q for port %d", testName, expected, actual, port)
99 }
100 }
101
102 for port, actual := range pf.received {
103 if _, ok := expectedFromClient[port]; !ok {
104 t.Errorf("%s: server unexpectedly received %q for port %d", testName, actual, port)
105 }
106 }
107 })
108 }
109
110 func TestForwardPorts(t *testing.T) {
111 tests := map[string]struct {
112 ports []string
113 clientSends map[int32]string
114 serverSends map[int32]string
115 }{
116 "forward 1 port with no data either direction": {
117 ports: []string{":5000"},
118 },
119 "forward 2 ports with bidirectional data": {
120 ports: []string{":5001", ":6000"},
121 clientSends: map[int32]string{
122 5001: "abcd",
123 6000: "ghij",
124 },
125 serverSends: map[int32]string{
126 5001: "1234",
127 6000: "5678",
128 },
129 },
130 }
131
132 for testName, test := range tests {
133 t.Run(testName, func(t *testing.T) {
134 server := httptest.NewServer(fakePortForwardServer(t, testName, test.serverSends, test.clientSends))
135 defer server.Close()
136
137 transport, upgrader, err := spdy.RoundTripperFor(&restclient.Config{})
138 if err != nil {
139 t.Fatal(err)
140 }
141 url, _ := url.Parse(server.URL)
142 dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", url)
143
144 stopChan := make(chan struct{}, 1)
145 readyChan := make(chan struct{})
146
147 pf, err := New(dialer, test.ports, stopChan, readyChan, os.Stdout, os.Stderr)
148 if err != nil {
149 t.Fatalf("%s: unexpected error calling New: %v", testName, err)
150 }
151
152 doneChan := make(chan error)
153 go func() {
154 doneChan <- pf.ForwardPorts()
155 }()
156 <-pf.Ready
157
158 forwardedPorts, err := pf.GetPorts()
159 if err != nil {
160 t.Fatal(err)
161 }
162
163 remoteToLocalMap := map[int32]int32{}
164 for _, forwardedPort := range forwardedPorts {
165 remoteToLocalMap[int32(forwardedPort.Remote)] = int32(forwardedPort.Local)
166 }
167
168 clientSend := func(port int32, data string) error {
169 clientConn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", remoteToLocalMap[port]))
170 if err != nil {
171 return fmt.Errorf("%s: error dialing %d: %s", testName, port, err)
172
173 }
174 defer clientConn.Close()
175
176 n, err := clientConn.Write([]byte(data))
177 if err != nil && err != io.EOF {
178 return fmt.Errorf("%s: Error sending data '%s': %s", testName, data, err)
179 }
180 if n == 0 {
181 return fmt.Errorf("%s: unexpected write of 0 bytes", testName)
182 }
183 b := make([]byte, 4)
184 _, err = clientConn.Read(b)
185 if err != nil && err != io.EOF {
186 return fmt.Errorf("%s: Error reading data: %s", testName, err)
187 }
188 if !bytes.Equal([]byte(test.serverSends[port]), b) {
189 return fmt.Errorf("%s: expected to read '%s', got '%s'", testName, test.serverSends[port], b)
190 }
191 return nil
192 }
193 for port, data := range test.clientSends {
194 if err := clientSend(port, data); err != nil {
195 t.Error(err)
196 }
197 }
198
199 close(stopChan)
200
201
202 err = <-doneChan
203 if err != nil {
204 t.Errorf("%s: unexpected error: %s", testName, err)
205 }
206 })
207 }
208
209 }
210
211 func TestForwardPortsReturnsErrorWhenAllBindsFailed(t *testing.T) {
212 server := httptest.NewServer(fakePortForwardServer(t, "allBindsFailed", nil, nil))
213 defer server.Close()
214
215 transport, upgrader, err := spdy.RoundTripperFor(&restclient.Config{})
216 if err != nil {
217 t.Fatal(err)
218 }
219 url, _ := url.Parse(server.URL)
220 dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", url)
221
222 stopChan1 := make(chan struct{}, 1)
223 defer close(stopChan1)
224 readyChan1 := make(chan struct{})
225
226 pf1, err := New(dialer, []string{":5555"}, stopChan1, readyChan1, os.Stdout, os.Stderr)
227 if err != nil {
228 t.Fatalf("error creating pf1: %v", err)
229 }
230 go pf1.ForwardPorts()
231 <-pf1.Ready
232
233 forwardedPorts, err := pf1.GetPorts()
234 if err != nil {
235 t.Fatal(err)
236 }
237 if len(forwardedPorts) != 1 {
238 t.Fatalf("expected 1 forwarded port, got %#v", forwardedPorts)
239 }
240 duplicateSpec := fmt.Sprintf("%d:%d", forwardedPorts[0].Local, forwardedPorts[0].Remote)
241
242 stopChan2 := make(chan struct{}, 1)
243 readyChan2 := make(chan struct{})
244 pf2, err := New(dialer, []string{duplicateSpec}, stopChan2, readyChan2, os.Stdout, os.Stderr)
245 if err != nil {
246 t.Fatalf("error creating pf2: %v", err)
247 }
248 if err := pf2.ForwardPorts(); err == nil {
249 t.Fatal("expected non-nil error for pf2.ForwardPorts")
250 }
251 }
252
View as plain text