1
2
3
4 package websocket_test
5
6 import (
7 "bytes"
8 "context"
9 "crypto/rand"
10 "io"
11 "net/http"
12 "net/http/httptest"
13 "net/url"
14 "strings"
15 "testing"
16 "time"
17
18 "nhooyr.io/websocket"
19 "nhooyr.io/websocket/internal/test/assert"
20 "nhooyr.io/websocket/internal/util"
21 "nhooyr.io/websocket/internal/xsync"
22 )
23
24 func TestBadDials(t *testing.T) {
25 t.Parallel()
26
27 t.Run("badReq", func(t *testing.T) {
28 t.Parallel()
29
30 testCases := []struct {
31 name string
32 url string
33 opts *websocket.DialOptions
34 rand util.ReaderFunc
35 nilCtx bool
36 }{
37 {
38 name: "badURL",
39 url: "://noscheme",
40 },
41 {
42 name: "badURLScheme",
43 url: "ftp://nhooyr.io",
44 },
45 {
46 name: "badTLS",
47 url: "wss://totallyfake.nhooyr.io",
48 },
49 {
50 name: "badReader",
51 rand: func(p []byte) (int, error) {
52 return 0, io.EOF
53 },
54 },
55 {
56 name: "nilContext",
57 url: "http://localhost",
58 nilCtx: true,
59 },
60 }
61
62 for _, tc := range testCases {
63 tc := tc
64 t.Run(tc.name, func(t *testing.T) {
65 t.Parallel()
66
67 var ctx context.Context
68 var cancel func()
69 if !tc.nilCtx {
70 ctx, cancel = context.WithTimeout(context.Background(), time.Second*5)
71 defer cancel()
72 }
73
74 if tc.rand == nil {
75 tc.rand = rand.Reader.Read
76 }
77
78 _, _, err := websocket.ExportedDial(ctx, tc.url, tc.opts, tc.rand)
79 assert.Error(t, err)
80 })
81 }
82 })
83
84 t.Run("badResponse", func(t *testing.T) {
85 t.Parallel()
86
87 ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
88 defer cancel()
89
90 _, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
91 HTTPClient: mockHTTPClient(func(*http.Request) (*http.Response, error) {
92 return &http.Response{
93 Body: io.NopCloser(strings.NewReader("hi")),
94 }, nil
95 }),
96 })
97 assert.Contains(t, err, "failed to WebSocket dial: expected handshake response status code 101 but got 0")
98 })
99
100 t.Run("badBody", func(t *testing.T) {
101 t.Parallel()
102
103 ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
104 defer cancel()
105
106 rt := func(r *http.Request) (*http.Response, error) {
107 h := http.Header{}
108 h.Set("Connection", "Upgrade")
109 h.Set("Upgrade", "websocket")
110 h.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))
111
112 return &http.Response{
113 StatusCode: http.StatusSwitchingProtocols,
114 Header: h,
115 Body: io.NopCloser(strings.NewReader("hi")),
116 }, nil
117 }
118
119 _, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
120 HTTPClient: mockHTTPClient(rt),
121 })
122 assert.Contains(t, err, "response body is not a io.ReadWriteCloser")
123 })
124 }
125
126 func Test_verifyHostOverride(t *testing.T) {
127 testCases := []struct {
128 name string
129 host string
130 exp string
131 }{
132 {
133 name: "noOverride",
134 host: "",
135 exp: "example.com",
136 },
137 {
138 name: "hostOverride",
139 host: "example.net",
140 exp: "example.net",
141 },
142 }
143
144 for _, tc := range testCases {
145 tc := tc
146 t.Run(tc.name, func(t *testing.T) {
147 t.Parallel()
148
149 ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
150 defer cancel()
151
152 rt := func(r *http.Request) (*http.Response, error) {
153 assert.Equal(t, "Host", tc.exp, r.Host)
154
155 h := http.Header{}
156 h.Set("Connection", "Upgrade")
157 h.Set("Upgrade", "websocket")
158 h.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))
159
160 return &http.Response{
161 StatusCode: http.StatusSwitchingProtocols,
162 Header: h,
163 Body: mockBody{bytes.NewBufferString("hi")},
164 }, nil
165 }
166
167 c, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
168 HTTPClient: mockHTTPClient(rt),
169 Host: tc.host,
170 })
171 assert.Success(t, err)
172 c.CloseNow()
173 })
174 }
175
176 }
177
// mockBody turns a *bytes.Buffer into an io.ReadWriteCloser so it can stand
// in as the body of a mocked 101 Switching Protocols response, which Dial
// requires to be read/write.
type mockBody struct {
	*bytes.Buffer
}

// Close satisfies io.Closer. The buffer holds no resources, so it always
// reports success.
func (mockBody) Close() error { return nil }
185
186 func Test_verifyServerHandshake(t *testing.T) {
187 t.Parallel()
188
189 testCases := []struct {
190 name string
191 response func(w http.ResponseWriter)
192 success bool
193 }{
194 {
195 name: "badStatus",
196 response: func(w http.ResponseWriter) {
197 w.WriteHeader(http.StatusOK)
198 },
199 success: false,
200 },
201 {
202 name: "badConnection",
203 response: func(w http.ResponseWriter) {
204 w.Header().Set("Connection", "???")
205 w.WriteHeader(http.StatusSwitchingProtocols)
206 },
207 success: false,
208 },
209 {
210 name: "badUpgrade",
211 response: func(w http.ResponseWriter) {
212 w.Header().Set("Connection", "Upgrade")
213 w.Header().Set("Upgrade", "???")
214 w.WriteHeader(http.StatusSwitchingProtocols)
215 },
216 success: false,
217 },
218 {
219 name: "badSecWebSocketAccept",
220 response: func(w http.ResponseWriter) {
221 w.Header().Set("Connection", "Upgrade")
222 w.Header().Set("Upgrade", "websocket")
223 w.Header().Set("Sec-WebSocket-Accept", "xd")
224 w.WriteHeader(http.StatusSwitchingProtocols)
225 },
226 success: false,
227 },
228 {
229 name: "badSecWebSocketProtocol",
230 response: func(w http.ResponseWriter) {
231 w.Header().Set("Connection", "Upgrade")
232 w.Header().Set("Upgrade", "websocket")
233 w.Header().Set("Sec-WebSocket-Protocol", "xd")
234 w.WriteHeader(http.StatusSwitchingProtocols)
235 },
236 success: false,
237 },
238 {
239 name: "unsupportedExtension",
240 response: func(w http.ResponseWriter) {
241 w.Header().Set("Connection", "Upgrade")
242 w.Header().Set("Upgrade", "websocket")
243 w.Header().Set("Sec-WebSocket-Extensions", "meow")
244 w.WriteHeader(http.StatusSwitchingProtocols)
245 },
246 success: false,
247 },
248 {
249 name: "unsupportedDeflateParam",
250 response: func(w http.ResponseWriter) {
251 w.Header().Set("Connection", "Upgrade")
252 w.Header().Set("Upgrade", "websocket")
253 w.Header().Set("Sec-WebSocket-Extensions", "permessage-deflate; meow")
254 w.WriteHeader(http.StatusSwitchingProtocols)
255 },
256 success: false,
257 },
258 {
259 name: "success",
260 response: func(w http.ResponseWriter) {
261 w.Header().Set("Connection", "Upgrade")
262 w.Header().Set("Upgrade", "websocket")
263 w.WriteHeader(http.StatusSwitchingProtocols)
264 },
265 success: true,
266 },
267 }
268
269 for _, tc := range testCases {
270 tc := tc
271 t.Run(tc.name, func(t *testing.T) {
272 t.Parallel()
273
274 w := httptest.NewRecorder()
275 tc.response(w)
276 resp := w.Result()
277
278 r := httptest.NewRequest("GET", "/", nil)
279 key, err := websocket.SecWebSocketKey(rand.Reader)
280 assert.Success(t, err)
281 r.Header.Set("Sec-WebSocket-Key", key)
282
283 if resp.Header.Get("Sec-WebSocket-Accept") == "" {
284 resp.Header.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(key))
285 }
286
287 opts := &websocket.DialOptions{
288 Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","),
289 }
290 _, err = websocket.VerifyServerResponse(opts, websocket.CompressionModeOpts(opts.CompressionMode), key, resp)
291 if tc.success {
292 assert.Success(t, err)
293 } else {
294 assert.Error(t, err)
295 }
296 })
297 }
298 }
299
// mockHTTPClient returns an *http.Client whose transport is fn, so tests can
// script HTTP responses without touching the network.
func mockHTTPClient(fn roundTripperFunc) *http.Client {
	client := &http.Client{}
	client.Transport = fn
	return client
}

// roundTripperFunc adapts a plain function to the http.RoundTripper
// interface, in the same spirit as http.HandlerFunc.
type roundTripperFunc func(*http.Request) (*http.Response, error)

// RoundTrip implements http.RoundTripper by calling the function itself.
func (fn roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := fn(req)
	return resp, err
}
311
312 func TestDialRedirect(t *testing.T) {
313 t.Parallel()
314
315 ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
316 defer cancel()
317
318 _, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
319 HTTPClient: mockHTTPClient(func(r *http.Request) (*http.Response, error) {
320 resp := &http.Response{
321 Header: http.Header{},
322 }
323 if r.URL.Scheme != "https" {
324 resp.Header.Set("Location", "wss://example.com")
325 resp.StatusCode = http.StatusFound
326 return resp, nil
327 }
328 resp.Header.Set("Connection", "Upgrade")
329 resp.Header.Set("Upgrade", "meow")
330 resp.StatusCode = http.StatusSwitchingProtocols
331 return resp, nil
332 }),
333 })
334 assert.Contains(t, err, "failed to WebSocket dial: WebSocket protocol violation: Upgrade header \"meow\" does not contain websocket")
335 }
336
337 type forwardProxy struct {
338 hc *http.Client
339 }
340
341 func newForwardProxy() *forwardProxy {
342 return &forwardProxy{
343 hc: &http.Client{},
344 }
345 }
346
347 func (fc *forwardProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
348 ctx, cancel := context.WithTimeout(r.Context(), time.Second*10)
349 defer cancel()
350
351 r = r.WithContext(ctx)
352 r.RequestURI = ""
353 resp, err := fc.hc.Do(r)
354 if err != nil {
355 http.Error(w, err.Error(), http.StatusBadRequest)
356 return
357 }
358 defer resp.Body.Close()
359
360 for k, v := range resp.Header {
361 w.Header()[k] = v
362 }
363 w.Header().Set("PROXIED", "true")
364 w.WriteHeader(resp.StatusCode)
365 if resprw, ok := resp.Body.(io.ReadWriter); ok {
366 c, brw, err := w.(http.Hijacker).Hijack()
367 if err != nil {
368 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
369 return
370 }
371 brw.Flush()
372
373 errc1 := xsync.Go(func() error {
374 _, err := io.Copy(c, resprw)
375 return err
376 })
377 errc2 := xsync.Go(func() error {
378 _, err := io.Copy(resprw, c)
379 return err
380 })
381 select {
382 case <-errc1:
383 case <-errc2:
384 case <-r.Context().Done():
385 }
386 } else {
387 io.Copy(w, resp.Body)
388 }
389 }
390
391 func TestDialViaProxy(t *testing.T) {
392 t.Parallel()
393
394 ps := httptest.NewServer(newForwardProxy())
395 defer ps.Close()
396
397 s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
398 err := echoServer(w, r, nil)
399 assert.Success(t, err)
400 }))
401 defer s.Close()
402
403 psu, err := url.Parse(ps.URL)
404 assert.Success(t, err)
405 proxyTransport := http.DefaultTransport.(*http.Transport).Clone()
406 proxyTransport.Proxy = http.ProxyURL(psu)
407
408 ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
409 defer cancel()
410 c, resp, err := websocket.Dial(ctx, s.URL, &websocket.DialOptions{
411 HTTPClient: &http.Client{
412 Transport: proxyTransport,
413 },
414 })
415 assert.Success(t, err)
416 assert.Equal(t, "", "true", resp.Header.Get("PROXIED"))
417
418 assertEcho(t, ctx, c)
419 assertClose(t, c)
420 }
421
View as plain text