...

Source file src/nhooyr.io/websocket/conn_test.go

Documentation: nhooyr.io/websocket

     1  //go:build !js
     2  
     3  package websocket_test
     4  
     5  import (
     6  	"bytes"
     7  	"context"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"net/http"
    12  	"net/http/httptest"
    13  	"os"
    14  	"os/exec"
    15  	"strings"
    16  	"testing"
    17  	"time"
    18  
    19  	"nhooyr.io/websocket"
    20  	"nhooyr.io/websocket/internal/errd"
    21  	"nhooyr.io/websocket/internal/test/assert"
    22  	"nhooyr.io/websocket/internal/test/wstest"
    23  	"nhooyr.io/websocket/internal/test/xrand"
    24  	"nhooyr.io/websocket/internal/xsync"
    25  	"nhooyr.io/websocket/wsjson"
    26  )
    27  
// TestConn runs a collection of subtests against connected pairs of
// websocket.Conn values obtained from newConnTest. Each subtest gets its own
// pair and its own deadline-bound context (tt.ctx).
func TestConn(t *testing.T) {
	t.Parallel()

	// fuzzData echoes messages over connections configured with randomly
	// chosen compression settings on both the dial and accept side.
	t.Run("fuzzData", func(t *testing.T) {
		t.Parallel()

		// compressionMode picks a random mode up to and including
		// CompressionContextTakeover (assuming xrand.Int(n) yields a value
		// in [0, n) — TODO confirm).
		compressionMode := func() websocket.CompressionMode {
			return websocket.CompressionMode(xrand.Int(int(websocket.CompressionContextTakeover) + 1))
		}

		for i := 0; i < 5; i++ {
			t.Run("", func(t *testing.T) {
				tt, c1, c2 := newConnTest(t, &websocket.DialOptions{
					CompressionMode:      compressionMode(),
					CompressionThreshold: xrand.Int(9999),
				}, &websocket.AcceptOptions{
					CompressionMode:      compressionMode(),
					CompressionThreshold: xrand.Int(9999),
				})

				tt.goEchoLoop(c2)

				// The read limit matches the size argument passed to
				// wstest.Echo below.
				c1.SetReadLimit(131072)

				for i := 0; i < 5; i++ {
					err := wstest.Echo(tt.ctx, c1, 131072)
					assert.Success(t, err)
				}

				err := c1.Close(websocket.StatusNormalClosure, "")
				assert.Success(t, err)
			})
		}
	})

	// badClose checks that closing with the status code -1 is rejected with a
	// close frame marshalling error.
	t.Run("badClose", func(t *testing.T) {
		tt, c1, c2 := newConnTest(t, nil, nil)

		c2.CloseRead(tt.ctx)

		err := c1.Close(-1, "")
		assert.Contains(t, err, "failed to marshal close frame: status code StatusCode(-1) cannot be set")
	})

	// ping sends ten pings from c1 after CloseRead has been called on both
	// ends and expects every one of them to succeed.
	t.Run("ping", func(t *testing.T) {
		tt, c1, c2 := newConnTest(t, nil, nil)

		c1.CloseRead(tt.ctx)
		c2.CloseRead(tt.ctx)

		for i := 0; i < 10; i++ {
			err := c1.Ping(tt.ctx)
			assert.Success(t, err)
		}

		err := c1.Close(websocket.StatusNormalClosure, "")
		assert.Success(t, err)
	})

	// badPing calls CloseRead only on c2 and expects a ping from c1 to fail
	// within 100ms while waiting for the pong.
	// NOTE(review): presumably the pong is never observed because nothing is
	// reading on c1 — confirm against Conn.Ping.
	t.Run("badPing", func(t *testing.T) {
		tt, c1, c2 := newConnTest(t, nil, nil)

		c2.CloseRead(tt.ctx)

		ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100)
		defer cancel()

		err := c1.Ping(ctx)
		assert.Contains(t, err, "failed to wait for pong")
	})

	// concurrentWrite issues 100 writes of the same random message from
	// separate goroutines and expects all of them to succeed.
	t.Run("concurrentWrite", func(t *testing.T) {
		tt, c1, c2 := newConnTest(t, nil, nil)

		tt.goDiscardLoop(c2)

		msg := xrand.Bytes(xrand.Int(9999))
		const count = 100
		// Buffered to count so that no writer goroutine blocks on reporting
		// its result.
		errs := make(chan error, count)

		for i := 0; i < count; i++ {
			go func() {
				select {
				case errs <- c1.Write(tt.ctx, websocket.MessageBinary, msg):
				case <-tt.ctx.Done():
					return
				}
			}()
		}

		for i := 0; i < count; i++ {
			select {
			case err := <-errs:
				assert.Success(t, err)
			case <-tt.ctx.Done():
				t.Fatal(tt.ctx.Err())
			}
		}

		err := c1.Close(websocket.StatusNormalClosure, "")
		assert.Success(t, err)
	})

	// concurrentWriteError opens a message writer that is never closed and
	// then expects a second write with a 100ms deadline to fail with
	// context.DeadlineExceeded.
	t.Run("concurrentWriteError", func(t *testing.T) {
		tt, c1, _ := newConnTest(t, nil, nil)

		_, err := c1.Writer(tt.ctx, websocket.MessageText)
		assert.Success(t, err)

		ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
		defer cancel()

		err = c1.Write(ctx, websocket.MessageText, []byte("x"))
		if !errors.Is(err, context.DeadlineExceeded) {
			t.Fatalf("unexpected error: %#v", err)
		}
	})

	// netConn wraps both ends with websocket.NetConn, checks the reported
	// addresses, and transfers "hello" from n2 to n1 followed by a close.
	t.Run("netConn", func(t *testing.T) {
		tt, c1, c2 := newConnTest(t, nil, nil)

		n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
		n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary)

		// Does not give any confidence but at least ensures no crashes.
		d, _ := tt.ctx.Deadline()
		n1.SetDeadline(d)
		n1.SetDeadline(time.Time{})

		assert.Equal(t, "remote addr", n1.RemoteAddr(), n1.LocalAddr())
		assert.Equal(t, "remote addr string", "pipe", n1.RemoteAddr().String())
		assert.Equal(t, "remote addr network", "pipe", n1.RemoteAddr().Network())

		// Write and close from the other end in the background while this
		// goroutine reads.
		errs := xsync.Go(func() error {
			_, err := n2.Write([]byte("hello"))
			if err != nil {
				return err
			}
			return n2.Close()
		})

		b, err := io.ReadAll(n1)
		assert.Success(t, err)

		// After the stream has been fully read, further reads must report
		// io.EOF.
		_, err = n1.Read(nil)
		assert.Equal(t, "read error", err, io.EOF)

		select {
		case err := <-errs:
			assert.Success(t, err)
		case <-tt.ctx.Done():
			t.Fatal(tt.ctx.Err())
		}

		assert.Equal(t, "read msg", []byte("hello"), b)
	})

	// netConn/BadMsg has n2 send text messages while n1 expects binary ones;
	// reading from n1 must fail with an unexpected frame type error.
	t.Run("netConn/BadMsg", func(t *testing.T) {
		tt, c1, c2 := newConnTest(t, nil, nil)

		n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
		n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageText)

		c2.CloseRead(tt.ctx)
		errs := xsync.Go(func() error {
			_, err := n2.Write([]byte("hello"))
			return err
		})

		_, err := io.ReadAll(n1)
		assert.Contains(t, err, `unexpected frame type read (expected MessageBinary): MessageText`)

		select {
		case err := <-errs:
			assert.Success(t, err)
		case <-tt.ctx.Done():
			t.Fatal(tt.ctx.Err())
		}
	})

	// netConn/readLimit sends a 4 MiB payload ("papa" repeated 1<<20 times)
	// through NetConn without calling SetReadLimit and expects it to arrive
	// intact.
	t.Run("netConn/readLimit", func(t *testing.T) {
		tt, c1, c2 := newConnTest(t, nil, nil)

		n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
		n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary)

		s := strings.Repeat("papa", 1<<20)
		errs := xsync.Go(func() error {
			_, err := n2.Write([]byte(s))
			if err != nil {
				return err
			}
			return n2.Close()
		})

		b, err := io.ReadAll(n1)
		assert.Success(t, err)

		_, err = n1.Read(nil)
		assert.Equal(t, "read error", err, io.EOF)

		select {
		case err := <-errs:
			assert.Success(t, err)
		case <-tt.ctx.Done():
			t.Fatal(tt.ctx.Err())
		}

		assert.Equal(t, "read msg", s, string(b))
	})

	// netConn/pastDeadline sets a deadline one minute in the past on both
	// ends; the test passes as long as nothing panics.
	t.Run("netConn/pastDeadline", func(t *testing.T) {
		tt, c1, c2 := newConnTest(t, nil, nil)

		n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
		n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary)

		n1.SetDeadline(time.Now().Add(-time.Minute))
		n2.SetDeadline(time.Now().Add(-time.Minute))

		// No panic we're good.
	})

	// wsjson writes a random string with wsjson.Write and expects to read the
	// same value back from the echo loop via wsjson.Read.
	t.Run("wsjson", func(t *testing.T) {
		tt, c1, c2 := newConnTest(t, nil, nil)

		tt.goEchoLoop(c2)

		c1.SetReadLimit(1 << 30)

		exp := xrand.String(xrand.Int(131072))

		// The write runs in the background so the read below can proceed
		// concurrently with it.
		werr := xsync.Go(func() error {
			return wsjson.Write(tt.ctx, c1, exp)
		})

		var act interface{}
		err := wsjson.Read(tt.ctx, c1, &act)
		assert.Success(t, err)
		assert.Equal(t, "read msg", exp, act)

		select {
		case err := <-werr:
			assert.Success(t, err)
		case <-tt.ctx.Done():
			t.Fatal(tt.ctx.Err())
		}

		err = c1.Close(websocket.StatusNormalClosure, "")
		assert.Success(t, err)
	})

	// HTTPClient.Timeout repeats the wsjson round trip on a connection dialed
	// with an http.Client whose Timeout is five seconds.
	t.Run("HTTPClient.Timeout", func(t *testing.T) {
		tt, c1, c2 := newConnTest(t, &websocket.DialOptions{
			HTTPClient: &http.Client{Timeout: time.Second * 5},
		}, nil)

		tt.goEchoLoop(c2)

		c1.SetReadLimit(1 << 30)

		exp := xrand.String(xrand.Int(131072))

		werr := xsync.Go(func() error {
			return wsjson.Write(tt.ctx, c1, exp)
		})

		var act interface{}
		err := wsjson.Read(tt.ctx, c1, &act)
		assert.Success(t, err)
		assert.Equal(t, "read msg", exp, act)

		select {
		case err := <-werr:
			assert.Success(t, err)
		case <-tt.ctx.Done():
			t.Fatal(tt.ctx.Err())
		}

		err = c1.Close(websocket.StatusNormalClosure, "")
		assert.Success(t, err)
	})

	// CloseNow expects the first CloseNow on each end to succeed and a second
	// call to report websocket.ErrClosed.
	t.Run("CloseNow", func(t *testing.T) {
		_, c1, c2 := newConnTest(t, nil, nil)

		err1 := c1.CloseNow()
		err2 := c2.CloseNow()
		assert.Success(t, err1)
		assert.Success(t, err2)
		err1 = c1.CloseNow()
		err2 = c2.CloseNow()
		assert.ErrorIs(t, websocket.ErrClosed, err1)
		assert.ErrorIs(t, websocket.ErrClosed, err2)
	})

	// MidReadClose obtains a message reader, leaves it unread, and expects a
	// normal Close to still succeed.
	t.Run("MidReadClose", func(t *testing.T) {
		tt, c1, c2 := newConnTest(t, nil, nil)

		tt.goEchoLoop(c2)

		c1.SetReadLimit(131072)

		for i := 0; i < 5; i++ {
			err := wstest.Echo(tt.ctx, c1, 131072)
			assert.Success(t, err)
		}

		err := wsjson.Write(tt.ctx, c1, "four")
		assert.Success(t, err)
		// The reader is deliberately discarded so the message is still
		// unread when Close is called.
		_, _, err = c1.Reader(tt.ctx)
		assert.Success(t, err)

		err = c1.Close(websocket.StatusNormalClosure, "")
		assert.Success(t, err)
	})
}
   345  
   346  func TestWasm(t *testing.T) {
   347  	t.Parallel()
   348  
   349  	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   350  		err := echoServer(w, r, &websocket.AcceptOptions{
   351  			Subprotocols:       []string{"echo"},
   352  			InsecureSkipVerify: true,
   353  		})
   354  		if err != nil {
   355  			t.Error(err)
   356  		}
   357  	}))
   358  	defer s.Close()
   359  
   360  	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
   361  	defer cancel()
   362  
   363  	cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", ".")
   364  	cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", s.URL))
   365  
   366  	b, err := cmd.CombinedOutput()
   367  	if err != nil {
   368  		t.Fatalf("wasm test binary failed: %v:\n%s", err, b)
   369  	}
   370  }
   371  
   372  func assertCloseStatus(exp websocket.StatusCode, err error) error {
   373  	if websocket.CloseStatus(err) == -1 {
   374  		return fmt.Errorf("expected websocket.CloseError: %T %v", err, err)
   375  	}
   376  	if websocket.CloseStatus(err) != exp {
   377  		return fmt.Errorf("expected close status %v but got %v", exp, err)
   378  	}
   379  	return nil
   380  }
   381  
// connTest bundles the per-test state shared by the connection test helpers.
type connTest struct {
	// t is the test or benchmark the helpers report errors to and register
	// cleanups on.
	t testing.TB
	// ctx is the context used for the connection operations of the test;
	// newConnTest creates it with a 30 second timeout.
	ctx context.Context
}
   386  
   387  func newConnTest(t testing.TB, dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (tt *connTest, c1, c2 *websocket.Conn) {
   388  	if t, ok := t.(*testing.T); ok {
   389  		t.Parallel()
   390  	}
   391  	t.Helper()
   392  
   393  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
   394  	tt = &connTest{t: t, ctx: ctx}
   395  	t.Cleanup(cancel)
   396  
   397  	c1, c2 = wstest.Pipe(dialOpts, acceptOpts)
   398  	if xrand.Bool() {
   399  		c1, c2 = c2, c1
   400  	}
   401  	t.Cleanup(func() {
   402  		c2.CloseNow()
   403  		c1.CloseNow()
   404  	})
   405  
   406  	return tt, c1, c2
   407  }
   408  
   409  func (tt *connTest) goEchoLoop(c *websocket.Conn) {
   410  	ctx, cancel := context.WithCancel(tt.ctx)
   411  
   412  	echoLoopErr := xsync.Go(func() error {
   413  		err := wstest.EchoLoop(ctx, c)
   414  		return assertCloseStatus(websocket.StatusNormalClosure, err)
   415  	})
   416  	tt.t.Cleanup(func() {
   417  		cancel()
   418  		err := <-echoLoopErr
   419  		if err != nil {
   420  			tt.t.Errorf("echo loop error: %v", err)
   421  		}
   422  	})
   423  }
   424  
   425  func (tt *connTest) goDiscardLoop(c *websocket.Conn) {
   426  	ctx, cancel := context.WithCancel(tt.ctx)
   427  
   428  	discardLoopErr := xsync.Go(func() error {
   429  		defer c.Close(websocket.StatusInternalError, "")
   430  
   431  		for {
   432  			_, _, err := c.Read(ctx)
   433  			if err != nil {
   434  				return assertCloseStatus(websocket.StatusNormalClosure, err)
   435  			}
   436  		}
   437  	})
   438  	tt.t.Cleanup(func() {
   439  		cancel()
   440  		err := <-discardLoopErr
   441  		if err != nil {
   442  			tt.t.Errorf("discard loop error: %v", err)
   443  		}
   444  	})
   445  }
   446  
   447  func BenchmarkConn(b *testing.B) {
   448  	var benchCases = []struct {
   449  		name string
   450  		mode websocket.CompressionMode
   451  	}{
   452  		{
   453  			name: "disabledCompress",
   454  			mode: websocket.CompressionDisabled,
   455  		},
   456  		{
   457  			name: "compressContextTakeover",
   458  			mode: websocket.CompressionContextTakeover,
   459  		},
   460  		{
   461  			name: "compressNoContext",
   462  			mode: websocket.CompressionNoContextTakeover,
   463  		},
   464  	}
   465  	for _, bc := range benchCases {
   466  		b.Run(bc.name, func(b *testing.B) {
   467  			bb, c1, c2 := newConnTest(b, &websocket.DialOptions{
   468  				CompressionMode: bc.mode,
   469  			}, &websocket.AcceptOptions{
   470  				CompressionMode: bc.mode,
   471  			})
   472  
   473  			bb.goEchoLoop(c2)
   474  
   475  			bytesWritten := c1.RecordBytesWritten()
   476  			bytesRead := c1.RecordBytesRead()
   477  
   478  			msg := []byte(strings.Repeat("1234", 128))
   479  			readBuf := make([]byte, len(msg))
   480  			writes := make(chan struct{})
   481  			defer close(writes)
   482  			werrs := make(chan error)
   483  
   484  			go func() {
   485  				for range writes {
   486  					select {
   487  					case werrs <- c1.Write(bb.ctx, websocket.MessageText, msg):
   488  					case <-bb.ctx.Done():
   489  						return
   490  					}
   491  				}
   492  			}()
   493  			b.SetBytes(int64(len(msg)))
   494  			b.ReportAllocs()
   495  			b.ResetTimer()
   496  			for i := 0; i < b.N; i++ {
   497  				select {
   498  				case writes <- struct{}{}:
   499  				case <-bb.ctx.Done():
   500  					b.Fatal(bb.ctx.Err())
   501  				}
   502  
   503  				typ, r, err := c1.Reader(bb.ctx)
   504  				if err != nil {
   505  					b.Fatal(i, err)
   506  				}
   507  				if websocket.MessageText != typ {
   508  					assert.Equal(b, "data type", websocket.MessageText, typ)
   509  				}
   510  
   511  				_, err = io.ReadFull(r, readBuf)
   512  				if err != nil {
   513  					b.Fatal(err)
   514  				}
   515  
   516  				n2, err := r.Read(readBuf)
   517  				if err != io.EOF {
   518  					assert.Equal(b, "read err", io.EOF, err)
   519  				}
   520  				if n2 != 0 {
   521  					assert.Equal(b, "n2", 0, n2)
   522  				}
   523  
   524  				if !bytes.Equal(msg, readBuf) {
   525  					assert.Equal(b, "msg", msg, readBuf)
   526  				}
   527  
   528  				select {
   529  				case err = <-werrs:
   530  				case <-bb.ctx.Done():
   531  					b.Fatal(bb.ctx.Err())
   532  				}
   533  				if err != nil {
   534  					b.Fatal(err)
   535  				}
   536  			}
   537  			b.StopTimer()
   538  
   539  			b.ReportMetric(float64(*bytesWritten/b.N), "written/op")
   540  			b.ReportMetric(float64(*bytesRead/b.N), "read/op")
   541  
   542  			err := c1.Close(websocket.StatusNormalClosure, "")
   543  			assert.Success(b, err)
   544  		})
   545  	}
   546  }
   547  
   548  func echoServer(w http.ResponseWriter, r *http.Request, opts *websocket.AcceptOptions) (err error) {
   549  	defer errd.Wrap(&err, "echo server failed")
   550  
   551  	c, err := websocket.Accept(w, r, opts)
   552  	if err != nil {
   553  		return err
   554  	}
   555  	defer c.Close(websocket.StatusInternalError, "")
   556  
   557  	err = wstest.EchoLoop(r.Context(), c)
   558  	return assertCloseStatus(websocket.StatusNormalClosure, err)
   559  }
   560  
   561  func assertEcho(tb testing.TB, ctx context.Context, c *websocket.Conn) {
   562  	exp := xrand.String(xrand.Int(131072))
   563  
   564  	werr := xsync.Go(func() error {
   565  		return wsjson.Write(ctx, c, exp)
   566  	})
   567  
   568  	var act interface{}
   569  	c.SetReadLimit(1 << 30)
   570  	err := wsjson.Read(ctx, c, &act)
   571  	assert.Success(tb, err)
   572  	assert.Equal(tb, "read msg", exp, act)
   573  
   574  	select {
   575  	case err := <-werr:
   576  		assert.Success(tb, err)
   577  	case <-ctx.Done():
   578  		tb.Fatal(ctx.Err())
   579  	}
   580  }
   581  
   582  func assertClose(tb testing.TB, c *websocket.Conn) {
   583  	tb.Helper()
   584  	err := c.Close(websocket.StatusNormalClosure, "")
   585  	assert.Success(tb, err)
   586  }
   587  
   588  func TestConcurrentClosePing(t *testing.T) {
   589  	t.Parallel()
   590  	for i := 0; i < 64; i++ {
   591  		func() {
   592  			c1, c2 := wstest.Pipe(nil, nil)
   593  			defer c1.CloseNow()
   594  			defer c2.CloseNow()
   595  			c1.CloseRead(context.Background())
   596  			c2.CloseRead(context.Background())
   597  			errc := xsync.Go(func() error {
   598  				for range time.Tick(time.Millisecond) {
   599  					err := c1.Ping(context.Background())
   600  					if err != nil {
   601  						return err
   602  					}
   603  				}
   604  				panic("unreachable")
   605  			})
   606  
   607  			time.Sleep(10 * time.Millisecond)
   608  			assert.Success(t, c1.Close(websocket.StatusNormalClosure, ""))
   609  			<-errc
   610  		}()
   611  	}
   612  }
   613  

View as plain text