1
16
17 package server
18
19 import (
20 "encoding/binary"
21 "fmt"
22 "io"
23 "strconv"
24 "sync"
25 "testing"
26
27 "github.com/stretchr/testify/assert"
28 "github.com/stretchr/testify/require"
29 "golang.org/x/net/websocket"
30
31 "k8s.io/apimachinery/pkg/types"
32 "k8s.io/kubelet/pkg/cri/streaming/portforward"
33 )
34
// Channel identifiers used in the first byte of every websocket frame
// exchanged by these tests (see wsWrite/wsRead). For the i-th forwarded
// port, data travels on channel 2*i+dataChannel and errors are reported on
// channel 2*i+errorChannel.
const (
	dataChannel  = 0
	errorChannel = 1
)
39
40 func TestServeWSPortForward(t *testing.T) {
41 tests := map[string]struct {
42 port string
43 uid bool
44 clientData string
45 containerData string
46 shouldError bool
47 }{
48 "no port": {port: "", shouldError: true},
49 "none number port": {port: "abc", shouldError: true},
50 "negative port": {port: "-1", shouldError: true},
51 "too large port": {port: "65536", shouldError: true},
52 "0 port": {port: "0", shouldError: true},
53 "min port": {port: "1", shouldError: false},
54 "normal port": {port: "8000", shouldError: false},
55 "normal port with data forward": {port: "8000", clientData: "client data", containerData: "container data", shouldError: false},
56 "max port": {port: "65535", shouldError: false},
57 "normal port with uid": {port: "8000", uid: true, shouldError: false},
58 }
59
60 podNamespace := "other"
61 podName := "foo"
62
63 for desc := range tests {
64 test := tests[desc]
65 t.Run(desc, func(t *testing.T) {
66 ss, err := newTestStreamingServer(0)
67 require.NoError(t, err)
68 defer ss.testHTTPServer.Close()
69 fw := newServerTestWithDebug(true, ss)
70 defer fw.testHTTPServer.Close()
71
72 portForwardFuncDone := make(chan struct{})
73
74 fw.fakeKubelet.getPortForwardCheck = func(name, namespace string, uid types.UID, opts portforward.V4Options) {
75 assert.Equal(t, podName, name, "pod name")
76 assert.Equal(t, podNamespace, namespace, "pod namespace")
77 if test.uid {
78 assert.Equal(t, testUID, string(uid), "uid")
79 }
80 }
81
82 ss.fakeRuntime.portForwardFunc = func(podSandboxID string, port int32, stream io.ReadWriteCloser) error {
83 defer close(portForwardFuncDone)
84 assert.Equal(t, testPodSandboxID, podSandboxID, "pod sandbox id")
85
86 testPort, err := strconv.ParseInt(test.port, 10, 32)
87 require.NoError(t, err, "parse port")
88 assert.Equal(t, int32(testPort), port, "port")
89
90 if test.clientData != "" {
91 fromClient := make([]byte, 32)
92 n, err := stream.Read(fromClient)
93 assert.NoError(t, err, "reading client data")
94 assert.Equal(t, test.clientData, string(fromClient[0:n]), "client data")
95 }
96
97 if test.containerData != "" {
98 _, err := stream.Write([]byte(test.containerData))
99 assert.NoError(t, err, "writing container data")
100 }
101
102 return nil
103 }
104
105 var url string
106 if test.uid {
107 url = fmt.Sprintf("ws://%s/portForward/%s/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, testUID, test.port)
108 } else {
109 url = fmt.Sprintf("ws://%s/portForward/%s/%s?port=%s", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName, test.port)
110 }
111
112 ws, err := websocket.Dial(url, "", "http://127.0.0.1/")
113 assert.Equal(t, test.shouldError, err != nil, "websocket dial")
114 if test.shouldError {
115 return
116 }
117 defer ws.Close()
118
119 p, err := strconv.ParseUint(test.port, 10, 16)
120 require.NoError(t, err, "parse port")
121 p16 := uint16(p)
122
123 channel, data, err := wsRead(ws)
124 require.NoError(t, err, "read")
125 assert.Equal(t, dataChannel, int(channel), "channel")
126 assert.Len(t, data, binary.Size(p16), "data size")
127 assert.Equal(t, p16, binary.LittleEndian.Uint16(data), "data")
128
129 channel, data, err = wsRead(ws)
130 assert.NoError(t, err, "read")
131 assert.Equal(t, errorChannel, int(channel), "channel")
132 assert.Len(t, data, binary.Size(p16), "data size")
133 assert.Equal(t, p16, binary.LittleEndian.Uint16(data), "data")
134
135 if test.clientData != "" {
136 println("writing the client data")
137 err := wsWrite(ws, dataChannel, []byte(test.clientData))
138 assert.NoError(t, err, "writing client data")
139 }
140
141 if test.containerData != "" {
142 _, data, err = wsRead(ws)
143 assert.NoError(t, err, "reading container data")
144 assert.Equal(t, test.containerData, string(data), "container data")
145 }
146
147 <-portForwardFuncDone
148 })
149 }
150 }
151
152 func TestServeWSMultiplePortForward(t *testing.T) {
153 portsText := []string{"7000,8000", "9000"}
154 ports := []uint16{7000, 8000, 9000}
155 podNamespace := "other"
156 podName := "foo"
157
158 ss, err := newTestStreamingServer(0)
159 require.NoError(t, err)
160 defer ss.testHTTPServer.Close()
161 fw := newServerTestWithDebug(true, ss)
162 defer fw.testHTTPServer.Close()
163
164 portForwardWG := sync.WaitGroup{}
165 portForwardWG.Add(len(ports))
166
167 portsMutex := sync.Mutex{}
168 portsForwarded := map[int32]struct{}{}
169
170 fw.fakeKubelet.getPortForwardCheck = func(name, namespace string, uid types.UID, opts portforward.V4Options) {
171 assert.Equal(t, podName, name, "pod name")
172 assert.Equal(t, podNamespace, namespace, "pod namespace")
173 }
174
175 ss.fakeRuntime.portForwardFunc = func(podSandboxID string, port int32, stream io.ReadWriteCloser) error {
176 defer portForwardWG.Done()
177 assert.Equal(t, testPodSandboxID, podSandboxID, "pod sandbox id")
178
179 portsMutex.Lock()
180 portsForwarded[port] = struct{}{}
181 portsMutex.Unlock()
182
183 fromClient := make([]byte, 32)
184 n, err := stream.Read(fromClient)
185 assert.NoError(t, err, "reading client data")
186 assert.Equal(t, fmt.Sprintf("client data on port %d", port), string(fromClient[0:n]), "client data")
187
188 _, err = stream.Write([]byte(fmt.Sprintf("container data on port %d", port)))
189 assert.NoError(t, err, "writing container data")
190
191 return nil
192 }
193
194 url := fmt.Sprintf("ws://%s/portForward/%s/%s?", fw.testHTTPServer.Listener.Addr().String(), podNamespace, podName)
195 for _, port := range portsText {
196 url = url + fmt.Sprintf("port=%s&", port)
197 }
198
199 ws, err := websocket.Dial(url, "", "http://127.0.0.1/")
200 require.NoError(t, err, "websocket dial")
201
202 defer ws.Close()
203
204 for i, port := range ports {
205 channel, data, err := wsRead(ws)
206 assert.NoError(t, err, "port %d read", port)
207 assert.Equal(t, i*2+dataChannel, int(channel), "port %d channel", port)
208 assert.Len(t, data, binary.Size(port), "port %d data size", port)
209 assert.Equal(t, binary.LittleEndian.Uint16(data), port, "port %d data", port)
210
211 channel, data, err = wsRead(ws)
212 assert.NoError(t, err, "port %d read", port)
213 assert.Equal(t, i*2+errorChannel, int(channel), "port %d channel", port)
214 assert.Len(t, data, binary.Size(port), "port %d data size", port)
215 assert.Equal(t, binary.LittleEndian.Uint16(data), port, "port %d data", port)
216 }
217
218 for i, port := range ports {
219 t.Logf("port %d writing the client data", port)
220 err := wsWrite(ws, byte(i*2+dataChannel), []byte(fmt.Sprintf("client data on port %d", port)))
221 assert.NoError(t, err, "port %d write client data", port)
222
223 channel, data, err := wsRead(ws)
224 assert.NoError(t, err, "port %d read container data", port)
225 assert.Equal(t, i*2+dataChannel, int(channel), "port %d channel", port)
226 assert.Equal(t, fmt.Sprintf("container data on port %d", port), string(data), "port %d container data", port)
227 }
228
229 portForwardWG.Wait()
230
231 portsMutex.Lock()
232 defer portsMutex.Unlock()
233 assert.Len(t, portsForwarded, len(ports), "all ports forwarded")
234 }
235
236 func wsWrite(conn *websocket.Conn, channel byte, data []byte) error {
237 frame := make([]byte, len(data)+1)
238 frame[0] = channel
239 copy(frame[1:], data)
240 err := websocket.Message.Send(conn, frame)
241 return err
242 }
243
244 func wsRead(conn *websocket.Conn) (byte, []byte, error) {
245 for {
246 var data []byte
247 err := websocket.Message.Receive(conn, &data)
248 if err != nil {
249 return 0, nil, err
250 }
251
252 if len(data) == 0 {
253 continue
254 }
255
256 channel := data[0]
257 data = data[1:]
258
259 return channel, data, err
260 }
261 }
262
View as plain text