1
16
17 package wsstream
18
19 import (
20 "encoding/base64"
21 "io"
22 "net/http"
23 "net/http/httptest"
24 "reflect"
25 "sync"
26 "testing"
27
28 "github.com/stretchr/testify/assert"
29 "github.com/stretchr/testify/require"
30 "golang.org/x/net/websocket"
31 )
32
// newServer starts an httptest server that dispatches to handler and returns
// the server together with the string form of its listener address. Callers
// in this file close the returned server themselves.
func newServer(handler http.Handler) (*httptest.Server, string) {
	srv := httptest.NewServer(handler)
	return srv, srv.Listener.Addr().String()
}
38
39 func TestRawConn(t *testing.T) {
40 channels := []ChannelType{ReadWriteChannel, ReadWriteChannel, IgnoreChannel, ReadChannel, WriteChannel}
41 conn := NewConn(NewDefaultChannelProtocols(channels))
42
43 s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
44 conn.Open(w, req)
45 }))
46 defer s.Close()
47
48 client, err := websocket.Dial("ws://"+addr, "", "http://localhost/")
49 if err != nil {
50 t.Fatal(err)
51 }
52 defer client.Close()
53
54 <-conn.ready
55 wg := sync.WaitGroup{}
56
57
58 wg.Add(1)
59 go func() {
60 defer wg.Done()
61 data, err := io.ReadAll(conn.channels[0])
62 if err != nil {
63 t.Error(err)
64 return
65 }
66 if !reflect.DeepEqual(data, []byte("client")) {
67 t.Errorf("unexpected server read: %v", data)
68 }
69 }()
70
71 if n, err := client.Write(append([]byte{0}, []byte("client")...)); err != nil || n != 7 {
72 t.Fatalf("%d: %v", n, err)
73 }
74
75
76 wg.Add(1)
77 go func() {
78 defer wg.Done()
79 if n, err := conn.channels[1].Write([]byte("server")); err != nil && n != 6 {
80 t.Errorf("%d: %v", n, err)
81 }
82 }()
83
84 data := make([]byte, 1024)
85 if n, err := io.ReadAtLeast(client, data, 6); n != 7 || err != nil {
86 t.Fatalf("%d: %v", n, err)
87 }
88 if !reflect.DeepEqual(data[:7], append([]byte{1}, []byte("server")...)) {
89 t.Errorf("unexpected client read: %v", data[:7])
90 }
91
92
93 if n, err := conn.channels[2].Write([]byte("test")); n != 4 || err != nil {
94 t.Errorf("writes should be ignored")
95 }
96 data = make([]byte, 1024)
97 if n, err := conn.channels[2].Read(data); n != 0 || err != io.EOF {
98 t.Errorf("reads should be ignored")
99 }
100
101
102 if n, err := conn.channels[3].Write([]byte("test")); n != 4 || err != nil {
103 t.Errorf("writes should be ignored")
104 }
105
106
107 data = make([]byte, 1024)
108 if n, err := conn.channels[4].Read(data); n != 0 || err != io.EOF {
109 t.Errorf("reads should be ignored")
110 }
111
112
113 if n, err := client.Write(append([]byte{4}, []byte("ignored")...)); err != nil || n != 8 {
114 t.Fatalf("%d: %v", n, err)
115 }
116
117 client.Close()
118 wg.Wait()
119 }
120
121 func TestBase64Conn(t *testing.T) {
122 conn := NewConn(NewDefaultChannelProtocols([]ChannelType{ReadWriteChannel, ReadWriteChannel}))
123 s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
124 conn.Open(w, req)
125 }))
126 defer s.Close()
127
128 config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
129 if err != nil {
130 t.Fatal(err)
131 }
132 config.Protocol = []string{"base64.channel.k8s.io"}
133 client, err := websocket.DialConfig(config)
134 if err != nil {
135 t.Fatal(err)
136 }
137 defer client.Close()
138
139 <-conn.ready
140 wg := sync.WaitGroup{}
141 wg.Add(1)
142 go func() {
143 defer wg.Done()
144 data, err := io.ReadAll(conn.channels[0])
145 if err != nil {
146 t.Error(err)
147 return
148 }
149 if !reflect.DeepEqual(data, []byte("client")) {
150 t.Errorf("unexpected server read: %s", string(data))
151 }
152 }()
153
154 clientData := base64.StdEncoding.EncodeToString([]byte("client"))
155 if n, err := client.Write(append([]byte{'0'}, clientData...)); err != nil || n != len(clientData)+1 {
156 t.Fatalf("%d: %v", n, err)
157 }
158
159 wg.Add(1)
160 go func() {
161 defer wg.Done()
162 if n, err := conn.channels[1].Write([]byte("server")); err != nil && n != 6 {
163 t.Errorf("%d: %v", n, err)
164 }
165 }()
166
167 data := make([]byte, 1024)
168 if n, err := io.ReadAtLeast(client, data, 9); n != 9 || err != nil {
169 t.Fatalf("%d: %v", n, err)
170 }
171 expect := []byte(base64.StdEncoding.EncodeToString([]byte("server")))
172
173 if !reflect.DeepEqual(data[:9], append([]byte{'1'}, expect...)) {
174 t.Errorf("unexpected client read: %v", data[:9])
175 }
176
177 client.Close()
178 wg.Wait()
179 }
180
// versionTest is one row of the protocol-negotiation table consumed by
// TestVersionedConn.
type versionTest struct {
	// supported maps each protocol name the server offers to the value used
	// for that protocol's ChannelProtocolConfig.Binary flag.
	supported map[string]bool
	// requested is the list of subprotocols the client sends when dialing.
	requested []string
	// error is true when the client dial is expected to fail.
	error bool
	// expected is the protocol name the server should report as selected.
	expected string
}
187
188 func versionTests() []versionTest {
189 const (
190 binary = true
191 base64 = false
192 )
193 return []versionTest{
194 {
195 supported: nil,
196 requested: []string{"raw"},
197 error: true,
198 },
199 {
200 supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
201 requested: nil,
202 expected: "",
203 },
204 {
205 supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
206 requested: []string{"v1.raw"},
207 error: true,
208 },
209 {
210 supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
211 requested: []string{"v1.raw", "v1.base64"},
212 error: true,
213 }, {
214 supported: map[string]bool{"": binary, "raw": binary, "base64": base64},
215 requested: []string{"v1.raw", "raw"},
216 expected: "raw",
217 },
218 {
219 supported: map[string]bool{"": binary, "v1.raw": binary, "v1.base64": base64, "v2.raw": binary, "v2.base64": base64},
220 requested: []string{"v1.raw"},
221 expected: "v1.raw",
222 },
223 {
224 supported: map[string]bool{"": binary, "v1.raw": binary, "v1.base64": base64, "v2.raw": binary, "v2.base64": base64},
225 requested: []string{"v2.base64"},
226 expected: "v2.base64",
227 },
228 }
229 }
230
231 func TestVersionedConn(t *testing.T) {
232 for i, test := range versionTests() {
233 func() {
234 supportedProtocols := map[string]ChannelProtocolConfig{}
235 for p, binary := range test.supported {
236 supportedProtocols[p] = ChannelProtocolConfig{
237 Binary: binary,
238 Channels: []ChannelType{ReadWriteChannel},
239 }
240 }
241 conn := NewConn(supportedProtocols)
242
243
244 selectedProtocol := make(chan string)
245 s, addr := newServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
246 p, _, _ := conn.Open(w, req)
247 selectedProtocol <- p
248 }))
249 defer s.Close()
250
251 config, err := websocket.NewConfig("ws://"+addr, "http://localhost/")
252 if err != nil {
253 t.Fatal(err)
254 }
255 config.Protocol = test.requested
256 client, err := websocket.DialConfig(config)
257 if err != nil {
258 if !test.error {
259 t.Fatalf("test %d: didn't expect error: %v", i, err)
260 } else {
261 return
262 }
263 }
264 defer client.Close()
265 if test.error && err == nil {
266 t.Fatalf("test %d: expected an error", i)
267 }
268
269 <-conn.ready
270 if got, expected := <-selectedProtocol, test.expected; got != expected {
271 t.Fatalf("test %d: unexpected protocol version: got=%s expected=%s", i, got, expected)
272 }
273 }()
274 }
275 }
276
277 func TestIsWebSocketRequestWithStreamCloseProtocol(t *testing.T) {
278 tests := map[string]struct {
279 headers map[string]string
280 expected bool
281 }{
282 "No headers returns false": {
283 headers: map[string]string{},
284 expected: false,
285 },
286 "Only connection upgrade header is false": {
287 headers: map[string]string{
288 "Connection": "upgrade",
289 },
290 expected: false,
291 },
292 "Only websocket upgrade header is false": {
293 headers: map[string]string{
294 "Upgrade": "websocket",
295 },
296 expected: false,
297 },
298 "Only websocket and connection upgrade headers is false": {
299 headers: map[string]string{
300 "Connection": "upgrade",
301 "Upgrade": "websocket",
302 },
303 expected: false,
304 },
305 "Missing connection/upgrade header is false": {
306 headers: map[string]string{
307 "Upgrade": "websocket",
308 WebSocketProtocolHeader: "v5.channel.k8s.io",
309 },
310 expected: false,
311 },
312 "Websocket connection upgrade headers with v5 protocol is true": {
313 headers: map[string]string{
314 "Connection": "upgrade",
315 "Upgrade": "websocket",
316 WebSocketProtocolHeader: "v5.channel.k8s.io",
317 },
318 expected: true,
319 },
320 "Websocket connection upgrade headers with wrong case v5 protocol is false": {
321 headers: map[string]string{
322 "Connection": "upgrade",
323 "Upgrade": "websocket",
324 WebSocketProtocolHeader: "v5.CHANNEL.k8s.io",
325 },
326 expected: false,
327 },
328 "Websocket connection upgrade headers with v4 protocol is false": {
329 headers: map[string]string{
330 "Connection": "upgrade",
331 "Upgrade": "websocket",
332 WebSocketProtocolHeader: "v4.channel.k8s.io",
333 },
334 expected: false,
335 },
336 "Websocket connection upgrade headers with multiple protocols but missing v5 is false": {
337 headers: map[string]string{
338 "Connection": "upgrade",
339 "Upgrade": "websocket",
340 WebSocketProtocolHeader: "v4.channel.k8s.io,v3.channel.k8s.io,v2.channel.k8s.io",
341 },
342 expected: false,
343 },
344 "Websocket connection upgrade headers with multiple protocols including v5 and spaces is true": {
345 headers: map[string]string{
346 "Connection": "upgrade",
347 "Upgrade": "websocket",
348 WebSocketProtocolHeader: "v5.channel.k8s.io, v4.channel.k8s.io",
349 },
350 expected: true,
351 },
352 "Websocket connection upgrade headers with multiple protocols out of order including v5 and spaces is true": {
353 headers: map[string]string{
354 "Connection": "upgrade",
355 "Upgrade": "websocket",
356 WebSocketProtocolHeader: "v4.channel.k8s.io, v5.channel.k8s.io, v3.channel.k8s.io",
357 },
358 expected: true,
359 },
360
361 "Websocket connection upgrade headers key is case-insensitive": {
362 headers: map[string]string{
363 "Connection": "upgrade",
364 "Upgrade": "websocket",
365 "sec-websocket-protocol": "v4.channel.k8s.io, v5.channel.k8s.io, v3.channel.k8s.io",
366 },
367 expected: true,
368 },
369 }
370
371 for name, test := range tests {
372 req, err := http.NewRequest("GET", "http://www.example.com/", nil)
373 require.NoError(t, err)
374 for key, value := range test.headers {
375 req.Header.Add(key, value)
376 }
377 actual := IsWebSocketRequestWithStreamCloseProtocol(req)
378 assert.Equal(t, test.expected, actual, "%s: expected (%t), got (%t)", name, test.expected, actual)
379 }
380 }
381
382 func TestProtocolSupportsStreamClose(t *testing.T) {
383 tests := map[string]struct {
384 protocol string
385 expected bool
386 }{
387 "empty protocol returns false": {
388 protocol: "",
389 expected: false,
390 },
391 "not binary protocol returns false": {
392 protocol: "base64.channel.k8s.io",
393 expected: false,
394 },
395 "V1 protocol returns false": {
396 protocol: "channel.k8s.io",
397 expected: false,
398 },
399 "V4 protocol returns false": {
400 protocol: "v4.channel.k8s.io",
401 expected: false,
402 },
403 "V5 protocol returns true": {
404 protocol: "v5.channel.k8s.io",
405 expected: true,
406 },
407 "V5 protocol wrong case returns false": {
408 protocol: "V5.channel.K8S.io",
409 expected: false,
410 },
411 }
412
413 for name, test := range tests {
414 actual := protocolSupportsStreamClose(test.protocol)
415 assert.Equal(t, test.expected, actual,
416 "%s: expected (%t), got (%t)", name, test.expected, actual)
417 }
418 }
419
View as plain text