...

Source file src/github.com/moby/spdystream/spdy_test.go

Documentation: github.com/moby/spdystream

     1  /*
     2     Copyright 2014-2021 Docker Inc.
     3  
     4     Licensed under the Apache License, Version 2.0 (the "License");
     5     you may not use this file except in compliance with the License.
     6     You may obtain a copy of the License at
     7  
     8         http://www.apache.org/licenses/LICENSE-2.0
     9  
    10     Unless required by applicable law or agreed to in writing, software
    11     distributed under the License is distributed on an "AS IS" BASIS,
    12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13     See the License for the specific language governing permissions and
    14     limitations under the License.
    15  */
    16  
    17  package spdystream
    18  
    19  import (
    20  	"bufio"
    21  	"bytes"
    22  	"fmt"
    23  	"io"
    24  	"io/ioutil"
    25  	"net"
    26  	"net/http"
    27  	"net/http/httptest"
    28  	"sync"
    29  	"testing"
    30  	"time"
    31  
    32  	"github.com/moby/spdystream/spdy"
    33  )
    34  
    35  func TestSpdyStreams(t *testing.T) {
    36  	var wg sync.WaitGroup
    37  	server, listen, serverErr := runServer(&wg)
    38  	if serverErr != nil {
    39  		t.Fatalf("Error initializing server: %s", serverErr)
    40  	}
    41  
    42  	conn, dialErr := net.Dial("tcp", listen)
    43  	if dialErr != nil {
    44  		t.Fatalf("Error dialing server: %s", dialErr)
    45  	}
    46  
    47  	spdyConn, spdyErr := NewConnection(conn, false)
    48  	if spdyErr != nil {
    49  		t.Fatalf("Error creating spdy connection: %s", spdyErr)
    50  	}
    51  	go spdyConn.Serve(NoOpStreamHandler)
    52  
    53  	authenticated = true
    54  	stream, streamErr := spdyConn.CreateStream(http.Header{}, nil, false)
    55  	if streamErr != nil {
    56  		t.Fatalf("Error creating stream: %s", streamErr)
    57  	}
    58  
    59  	waitErr := stream.Wait()
    60  	if waitErr != nil {
    61  		t.Fatalf("Error waiting for stream: %s", waitErr)
    62  	}
    63  
    64  	message := []byte("hello")
    65  	writeErr := stream.WriteData(message, false)
    66  	if writeErr != nil {
    67  		t.Fatalf("Error writing data")
    68  	}
    69  
    70  	buf := make([]byte, 10)
    71  	n, readErr := stream.Read(buf)
    72  	if readErr != nil {
    73  		t.Fatalf("Error reading data from stream: %s", readErr)
    74  	}
    75  	if n != 5 {
    76  		t.Fatalf("Unexpected number of bytes read:\nActual: %d\nExpected: 5", n)
    77  	}
    78  	if !bytes.Equal(buf[:n], message) {
    79  		t.Fatalf("Did not receive expected message:\nActual: %s\nExpectd: %s", buf, message)
    80  	}
    81  
    82  	headers := http.Header{
    83  		"TestKey": []string{"TestVal"},
    84  	}
    85  	sendErr := stream.SendHeader(headers, false)
    86  	if sendErr != nil {
    87  		t.Fatalf("Error sending headers: %s", sendErr)
    88  	}
    89  	receiveHeaders, receiveErr := stream.ReceiveHeader()
    90  	if receiveErr != nil {
    91  		t.Fatalf("Error receiving headers: %s", receiveErr)
    92  	}
    93  	if len(receiveHeaders) != 1 {
    94  		t.Fatalf("Unexpected number of headers:\nActual: %d\nExpecting:%d", len(receiveHeaders), 1)
    95  	}
    96  	testVal := receiveHeaders.Get("TestKey")
    97  	if testVal != "TestVal" {
    98  		t.Fatalf("Wrong test value:\nActual: %q\nExpecting: %q", testVal, "TestVal")
    99  	}
   100  
   101  	writeErr = stream.WriteData(message, true)
   102  	if writeErr != nil {
   103  		t.Fatalf("Error writing data")
   104  	}
   105  
   106  	smallBuf := make([]byte, 3)
   107  	n, readErr = stream.Read(smallBuf)
   108  	if readErr != nil {
   109  		t.Fatalf("Error reading data from stream: %s", readErr)
   110  	}
   111  	if n != 3 {
   112  		t.Fatalf("Unexpected number of bytes read:\nActual: %d\nExpected: 3", n)
   113  	}
   114  	if !bytes.Equal(smallBuf[:n], []byte("hel")) {
   115  		t.Fatalf("Did not receive expected message:\nActual: %s\nExpectd: %s", smallBuf[:n], message)
   116  	}
   117  	n, readErr = stream.Read(smallBuf)
   118  	if readErr != nil {
   119  		t.Fatalf("Error reading data from stream: %s", readErr)
   120  	}
   121  	if n != 2 {
   122  		t.Fatalf("Unexpected number of bytes read:\nActual: %d\nExpected: 2", n)
   123  	}
   124  	if !bytes.Equal(smallBuf[:n], []byte("lo")) {
   125  		t.Fatalf("Did not receive expected message:\nActual: %s\nExpected: lo", smallBuf[:n])
   126  	}
   127  
   128  	n, readErr = stream.Read(buf)
   129  	if readErr != io.EOF {
   130  		t.Fatalf("Expected EOF reading from finished stream, read %d bytes", n)
   131  	}
   132  
   133  	// Closing again should return error since stream is already closed
   134  	streamCloseErr := stream.Close()
   135  	if streamCloseErr == nil {
   136  		t.Fatalf("No error closing finished stream")
   137  	}
   138  	if streamCloseErr != ErrWriteClosedStream {
   139  		t.Fatalf("Unexpected error closing stream: %s", streamCloseErr)
   140  	}
   141  
   142  	streamResetErr := stream.Reset()
   143  	if streamResetErr != nil {
   144  		t.Fatalf("Error reseting stream: %s", streamResetErr)
   145  	}
   146  
   147  	authenticated = false
   148  	badStream, badStreamErr := spdyConn.CreateStream(http.Header{}, nil, false)
   149  	if badStreamErr != nil {
   150  		t.Fatalf("Error creating stream: %s", badStreamErr)
   151  	}
   152  
   153  	waitErr = badStream.Wait()
   154  	if waitErr == nil {
   155  		t.Fatalf("Did not receive error creating stream")
   156  	}
   157  	if waitErr != ErrReset {
   158  		t.Fatalf("Unexpected error creating stream: %s", waitErr)
   159  	}
   160  	streamCloseErr = badStream.Close()
   161  	if streamCloseErr == nil {
   162  		t.Fatalf("No error closing bad stream")
   163  	}
   164  
   165  	spdyCloseErr := spdyConn.Close()
   166  	if spdyCloseErr != nil {
   167  		t.Fatalf("Error closing spdy connection: %s", spdyCloseErr)
   168  	}
   169  
   170  	closeErr := server.Close()
   171  	if closeErr != nil {
   172  		t.Fatalf("Error shutting down server: %s", closeErr)
   173  	}
   174  	wg.Wait()
   175  }
   176  
   177  func TestPing(t *testing.T) {
   178  	var wg sync.WaitGroup
   179  	server, listen, serverErr := runServer(&wg)
   180  	if serverErr != nil {
   181  		t.Fatalf("Error initializing server: %s", serverErr)
   182  	}
   183  
   184  	conn, dialErr := net.Dial("tcp", listen)
   185  	if dialErr != nil {
   186  		t.Fatalf("Error dialing server: %s", dialErr)
   187  	}
   188  
   189  	spdyConn, spdyErr := NewConnection(conn, false)
   190  	if spdyErr != nil {
   191  		t.Fatalf("Error creating spdy connection: %s", spdyErr)
   192  	}
   193  	go spdyConn.Serve(NoOpStreamHandler)
   194  
   195  	pingTime, pingErr := spdyConn.Ping()
   196  	if pingErr != nil {
   197  		t.Fatalf("Error pinging server: %s", pingErr)
   198  	}
   199  	if pingTime == time.Duration(0) {
   200  		t.Fatalf("Expecting non-zero ping time")
   201  	}
   202  
   203  	closeErr := server.Close()
   204  	if closeErr != nil {
   205  		t.Fatalf("Error shutting down server: %s", closeErr)
   206  	}
   207  	wg.Wait()
   208  }
   209  
   210  func TestHalfClose(t *testing.T) {
   211  	var wg sync.WaitGroup
   212  	server, listen, serverErr := runServer(&wg)
   213  	if serverErr != nil {
   214  		t.Fatalf("Error initializing server: %s", serverErr)
   215  	}
   216  
   217  	conn, dialErr := net.Dial("tcp", listen)
   218  	if dialErr != nil {
   219  		t.Fatalf("Error dialing server: %s", dialErr)
   220  	}
   221  
   222  	spdyConn, spdyErr := NewConnection(conn, false)
   223  	if spdyErr != nil {
   224  		t.Fatalf("Error creating spdy connection: %s", spdyErr)
   225  	}
   226  	go spdyConn.Serve(NoOpStreamHandler)
   227  
   228  	authenticated = true
   229  	stream, streamErr := spdyConn.CreateStream(http.Header{}, nil, false)
   230  	if streamErr != nil {
   231  		t.Fatalf("Error creating stream: %s", streamErr)
   232  	}
   233  
   234  	waitErr := stream.Wait()
   235  	if waitErr != nil {
   236  		t.Fatalf("Error waiting for stream: %s", waitErr)
   237  	}
   238  
   239  	message := []byte("hello and will read after close")
   240  	writeErr := stream.WriteData(message, false)
   241  	if writeErr != nil {
   242  		t.Fatalf("Error writing data")
   243  	}
   244  
   245  	streamCloseErr := stream.Close()
   246  	if streamCloseErr != nil {
   247  		t.Fatalf("Error closing stream: %s", streamCloseErr)
   248  	}
   249  
   250  	buf := make([]byte, 40)
   251  	n, readErr := stream.Read(buf)
   252  	if readErr != nil {
   253  		t.Fatalf("Error reading data from stream: %s", readErr)
   254  	}
   255  	if n != 31 {
   256  		t.Fatalf("Unexpected number of bytes read:\nActual: %d\nExpected: 5", n)
   257  	}
   258  	if !bytes.Equal(buf[:n], message) {
   259  		t.Fatalf("Did not receive expected message:\nActual: %s\nExpectd: %s", buf, message)
   260  	}
   261  
   262  	spdyCloseErr := spdyConn.Close()
   263  	if spdyCloseErr != nil {
   264  		t.Fatalf("Error closing spdy connection: %s", spdyCloseErr)
   265  	}
   266  
   267  	closeErr := server.Close()
   268  	if closeErr != nil {
   269  		t.Fatalf("Error shutting down server: %s", closeErr)
   270  	}
   271  	wg.Wait()
   272  }
   273  
   274  func TestUnexpectedRemoteConnectionClosed(t *testing.T) {
   275  	tt := []struct {
   276  		closeReceiver bool
   277  		closeSender   bool
   278  	}{
   279  		{closeReceiver: true, closeSender: false},
   280  		{closeReceiver: false, closeSender: true},
   281  		{closeReceiver: false, closeSender: false},
   282  	}
   283  	for tix, tc := range tt {
   284  		listener, listenErr := net.Listen("tcp", "localhost:0")
   285  		if listenErr != nil {
   286  			t.Fatalf("Error listening: %v", listenErr)
   287  		}
   288  
   289  		var serverConn net.Conn
   290  		var connErr error
   291  		go func() {
   292  			serverConn, connErr = listener.Accept()
   293  			if connErr != nil {
   294  				t.Errorf("Error accepting: %v", connErr)
   295  			}
   296  
   297  			serverSpdyConn, _ := NewConnection(serverConn, true)
   298  			go serverSpdyConn.Serve(func(stream *Stream) {
   299  				stream.SendReply(http.Header{}, tc.closeSender)
   300  			})
   301  		}()
   302  
   303  		conn, dialErr := net.Dial("tcp", listener.Addr().String())
   304  		if dialErr != nil {
   305  			t.Fatalf("Error dialing server: %s", dialErr)
   306  		}
   307  
   308  		spdyConn, spdyErr := NewConnection(conn, false)
   309  		if spdyErr != nil {
   310  			t.Fatalf("Error creating spdy connection: %s", spdyErr)
   311  		}
   312  		go spdyConn.Serve(NoOpStreamHandler)
   313  
   314  		authenticated = true
   315  		stream, streamErr := spdyConn.CreateStream(http.Header{}, nil, false)
   316  		if streamErr != nil {
   317  			t.Fatalf("Error creating stream: %s", streamErr)
   318  		}
   319  
   320  		waitErr := stream.Wait()
   321  		if waitErr != nil {
   322  			t.Fatalf("Error waiting for stream: %s", waitErr)
   323  		}
   324  
   325  		if tc.closeReceiver {
   326  			// make stream half closed, receive only
   327  			stream.Close()
   328  		}
   329  
   330  		streamch := make(chan error, 1)
   331  		go func() {
   332  			b := make([]byte, 1)
   333  			_, err := stream.Read(b)
   334  			streamch <- err
   335  		}()
   336  
   337  		closeErr := serverConn.Close()
   338  		if closeErr != nil {
   339  			t.Fatalf("Error shutting down server: %s", closeErr)
   340  		}
   341  
   342  		e := <-streamch
   343  		if e == nil || e != io.EOF {
   344  			t.Fatalf("(%d) Expected to get an EOF stream error", tix)
   345  		}
   346  
   347  		closeErr = conn.Close()
   348  		if closeErr != nil {
   349  			t.Fatalf("Error closing client connection: %s", closeErr)
   350  		}
   351  
   352  		listenErr = listener.Close()
   353  		if listenErr != nil {
   354  			t.Fatalf("Error closing listener: %s", listenErr)
   355  		}
   356  	}
   357  }
   358  
   359  func TestCloseNotification(t *testing.T) {
   360  	listener, listenErr := net.Listen("tcp", "localhost:0")
   361  	if listenErr != nil {
   362  		t.Fatalf("Error listening: %v", listenErr)
   363  	}
   364  	listen := listener.Addr().String()
   365  
   366  	serverConnChan := make(chan net.Conn)
   367  	go func() {
   368  		serverConn, err := listener.Accept()
   369  		if err != nil {
   370  			t.Errorf("Error accepting: %v", err)
   371  		}
   372  
   373  		serverSpdyConn, err := NewConnection(serverConn, true)
   374  		if err != nil {
   375  			t.Errorf("Error creating server connection: %v", err)
   376  		}
   377  		go serverSpdyConn.Serve(NoOpStreamHandler)
   378  		<-serverSpdyConn.CloseChan()
   379  		serverConnChan <- serverConn
   380  	}()
   381  
   382  	conn, dialErr := net.Dial("tcp", listen)
   383  	if dialErr != nil {
   384  		t.Fatalf("Error dialing server: %s", dialErr)
   385  	}
   386  
   387  	spdyConn, spdyErr := NewConnection(conn, false)
   388  	if spdyErr != nil {
   389  		t.Fatalf("Error creating spdy connection: %s", spdyErr)
   390  	}
   391  	go spdyConn.Serve(NoOpStreamHandler)
   392  
   393  	// close client conn
   394  	err := conn.Close()
   395  	if err != nil {
   396  		t.Fatalf("Error closing client connection: %v", err)
   397  	}
   398  
   399  	serverConn := <-serverConnChan
   400  
   401  	err = serverConn.Close()
   402  	if err != nil {
   403  		t.Fatalf("Error closing serverConn: %v", err)
   404  	}
   405  
   406  	listenErr = listener.Close()
   407  	if listenErr != nil {
   408  		t.Fatalf("Error closing listener: %s", listenErr)
   409  	}
   410  }
   411  
   412  func TestIdleShutdownRace(t *testing.T) {
   413  	var wg sync.WaitGroup
   414  	server, listen, serverErr := runServer(&wg)
   415  	if serverErr != nil {
   416  		t.Fatalf("Error initializing server: %s", serverErr)
   417  	}
   418  
   419  	conn, dialErr := net.Dial("tcp", listen)
   420  	if dialErr != nil {
   421  		t.Fatalf("Error dialing server: %s", dialErr)
   422  	}
   423  
   424  	spdyConn, spdyErr := NewConnection(conn, false)
   425  	if spdyErr != nil {
   426  		t.Fatalf("Error creating spdy connection: %s", spdyErr)
   427  	}
   428  	go spdyConn.Serve(NoOpStreamHandler)
   429  
   430  	authenticated = true
   431  	stream, err := spdyConn.CreateStream(http.Header{}, nil, false)
   432  	if err != nil {
   433  		t.Fatalf("Error creating stream: %v", err)
   434  	}
   435  
   436  	spdyConn.SetIdleTimeout(5 * time.Millisecond)
   437  	go func() {
   438  		time.Sleep(5 * time.Millisecond)
   439  		stream.Reset()
   440  	}()
   441  
   442  	select {
   443  	case <-spdyConn.CloseChan():
   444  	case <-time.After(20 * time.Millisecond):
   445  		t.Fatal("Timed out waiting for idle connection closure")
   446  	}
   447  
   448  	closeErr := server.Close()
   449  	if closeErr != nil {
   450  		t.Fatalf("Error shutting down server: %s", closeErr)
   451  	}
   452  	wg.Wait()
   453  }
   454  
   455  func TestIdleNoTimeoutSet(t *testing.T) {
   456  	var wg sync.WaitGroup
   457  	server, listen, serverErr := runServer(&wg)
   458  	if serverErr != nil {
   459  		t.Fatalf("Error initializing server: %s", serverErr)
   460  	}
   461  
   462  	conn, dialErr := net.Dial("tcp", listen)
   463  	if dialErr != nil {
   464  		t.Fatalf("Error dialing server: %s", dialErr)
   465  	}
   466  
   467  	spdyConn, spdyErr := NewConnection(conn, false)
   468  	if spdyErr != nil {
   469  		t.Fatalf("Error creating spdy connection: %s", spdyErr)
   470  	}
   471  	go spdyConn.Serve(NoOpStreamHandler)
   472  
   473  	select {
   474  	case <-spdyConn.CloseChan():
   475  		t.Fatal("Unexpected connection closure")
   476  	case <-time.After(10 * time.Millisecond):
   477  	}
   478  
   479  	closeErr := server.Close()
   480  	if closeErr != nil {
   481  		t.Fatalf("Error shutting down server: %s", closeErr)
   482  	}
   483  	wg.Wait()
   484  }
   485  
   486  func TestIdleClearTimeout(t *testing.T) {
   487  	var wg sync.WaitGroup
   488  	server, listen, serverErr := runServer(&wg)
   489  	if serverErr != nil {
   490  		t.Fatalf("Error initializing server: %s", serverErr)
   491  	}
   492  
   493  	conn, dialErr := net.Dial("tcp", listen)
   494  	if dialErr != nil {
   495  		t.Fatalf("Error dialing server: %s", dialErr)
   496  	}
   497  
   498  	spdyConn, spdyErr := NewConnection(conn, false)
   499  	if spdyErr != nil {
   500  		t.Fatalf("Error creating spdy connection: %s", spdyErr)
   501  	}
   502  	go spdyConn.Serve(NoOpStreamHandler)
   503  
   504  	spdyConn.SetIdleTimeout(10 * time.Millisecond)
   505  	spdyConn.SetIdleTimeout(0)
   506  	select {
   507  	case <-spdyConn.CloseChan():
   508  		t.Fatal("Unexpected connection closure")
   509  	case <-time.After(20 * time.Millisecond):
   510  	}
   511  
   512  	closeErr := server.Close()
   513  	if closeErr != nil {
   514  		t.Fatalf("Error shutting down server: %s", closeErr)
   515  	}
   516  	wg.Wait()
   517  }
   518  
   519  func TestIdleNoData(t *testing.T) {
   520  	var wg sync.WaitGroup
   521  	server, listen, serverErr := runServer(&wg)
   522  	if serverErr != nil {
   523  		t.Fatalf("Error initializing server: %s", serverErr)
   524  	}
   525  
   526  	conn, dialErr := net.Dial("tcp", listen)
   527  	if dialErr != nil {
   528  		t.Fatalf("Error dialing server: %s", dialErr)
   529  	}
   530  
   531  	spdyConn, spdyErr := NewConnection(conn, false)
   532  	if spdyErr != nil {
   533  		t.Fatalf("Error creating spdy connection: %s", spdyErr)
   534  	}
   535  	go spdyConn.Serve(NoOpStreamHandler)
   536  
   537  	spdyConn.SetIdleTimeout(10 * time.Millisecond)
   538  	<-spdyConn.CloseChan()
   539  
   540  	closeErr := server.Close()
   541  	if closeErr != nil {
   542  		t.Fatalf("Error shutting down server: %s", closeErr)
   543  	}
   544  	wg.Wait()
   545  }
   546  
   547  func TestIdleWithData(t *testing.T) {
   548  	var wg sync.WaitGroup
   549  	server, listen, serverErr := runServer(&wg)
   550  	if serverErr != nil {
   551  		t.Fatalf("Error initializing server: %s", serverErr)
   552  	}
   553  
   554  	conn, dialErr := net.Dial("tcp", listen)
   555  	if dialErr != nil {
   556  		t.Fatalf("Error dialing server: %s", dialErr)
   557  	}
   558  
   559  	spdyConn, spdyErr := NewConnection(conn, false)
   560  	if spdyErr != nil {
   561  		t.Fatalf("Error creating spdy connection: %s", spdyErr)
   562  	}
   563  	go spdyConn.Serve(NoOpStreamHandler)
   564  
   565  	spdyConn.SetIdleTimeout(25 * time.Millisecond)
   566  
   567  	authenticated = true
   568  	stream, err := spdyConn.CreateStream(http.Header{}, nil, false)
   569  	if err != nil {
   570  		t.Fatalf("Error creating stream: %v", err)
   571  	}
   572  
   573  	writeCh := make(chan struct{})
   574  
   575  	go func() {
   576  		b := []byte{1, 2, 3, 4, 5}
   577  		for i := 0; i < 10; i++ {
   578  			_, err = stream.Write(b)
   579  			if err != nil {
   580  				t.Errorf("Error writing to stream: %v", err)
   581  			}
   582  			time.Sleep(10 * time.Millisecond)
   583  		}
   584  		close(writeCh)
   585  	}()
   586  
   587  	writesFinished := false
   588  
   589  Loop:
   590  	for {
   591  		select {
   592  		case <-writeCh:
   593  			writesFinished = true
   594  		case <-spdyConn.CloseChan():
   595  			if !writesFinished {
   596  				t.Fatal("Connection closed before all writes finished")
   597  			}
   598  			break Loop
   599  		}
   600  	}
   601  
   602  	closeErr := server.Close()
   603  	if closeErr != nil {
   604  		t.Fatalf("Error shutting down server: %s", closeErr)
   605  	}
   606  	wg.Wait()
   607  }
   608  
   609  func TestIdleRace(t *testing.T) {
   610  	var wg sync.WaitGroup
   611  	server, listen, serverErr := runServer(&wg)
   612  	if serverErr != nil {
   613  		t.Fatalf("Error initializing server: %s", serverErr)
   614  	}
   615  
   616  	conn, dialErr := net.Dial("tcp", listen)
   617  	if dialErr != nil {
   618  		t.Fatalf("Error dialing server: %s", dialErr)
   619  	}
   620  
   621  	spdyConn, spdyErr := NewConnection(conn, false)
   622  	if spdyErr != nil {
   623  		t.Fatalf("Error creating spdy connection: %s", spdyErr)
   624  	}
   625  	go spdyConn.Serve(NoOpStreamHandler)
   626  
   627  	spdyConn.SetIdleTimeout(10 * time.Millisecond)
   628  
   629  	authenticated = true
   630  
   631  	for i := 0; i < 10; i++ {
   632  		_, err := spdyConn.CreateStream(http.Header{}, nil, false)
   633  		if err != nil {
   634  			t.Fatalf("Error creating stream: %v", err)
   635  		}
   636  	}
   637  
   638  	<-spdyConn.CloseChan()
   639  
   640  	closeErr := server.Close()
   641  	if closeErr != nil {
   642  		t.Fatalf("Error shutting down server: %s", closeErr)
   643  	}
   644  	wg.Wait()
   645  }
   646  
   647  func TestHalfClosedIdleTimeout(t *testing.T) {
   648  	listener, listenErr := net.Listen("tcp", "localhost:0")
   649  	if listenErr != nil {
   650  		t.Fatalf("Error listening: %v", listenErr)
   651  	}
   652  	listen := listener.Addr().String()
   653  
   654  	go func() {
   655  		serverConn, err := listener.Accept()
   656  		if err != nil {
   657  			t.Errorf("Error accepting: %v", err)
   658  		}
   659  
   660  		serverSpdyConn, err := NewConnection(serverConn, true)
   661  		if err != nil {
   662  			t.Errorf("Error creating server connection: %v", err)
   663  		}
   664  		go serverSpdyConn.Serve(func(s *Stream) {
   665  			s.SendReply(http.Header{}, true)
   666  		})
   667  		serverSpdyConn.SetIdleTimeout(10 * time.Millisecond)
   668  	}()
   669  
   670  	conn, dialErr := net.Dial("tcp", listen)
   671  	if dialErr != nil {
   672  		t.Fatalf("Error dialing server: %s", dialErr)
   673  	}
   674  
   675  	spdyConn, spdyErr := NewConnection(conn, false)
   676  	if spdyErr != nil {
   677  		t.Fatalf("Error creating spdy connection: %s", spdyErr)
   678  	}
   679  	go spdyConn.Serve(NoOpStreamHandler)
   680  
   681  	stream, err := spdyConn.CreateStream(http.Header{}, nil, false)
   682  	if err != nil {
   683  		t.Fatalf("Error creating stream: %v", err)
   684  	}
   685  
   686  	time.Sleep(20 * time.Millisecond)
   687  
   688  	stream.Reset()
   689  
   690  	err = spdyConn.Close()
   691  	if err != nil {
   692  		t.Fatalf("Error closing client spdy conn: %v", err)
   693  	}
   694  }
   695  
   696  func TestStreamReset(t *testing.T) {
   697  	var wg sync.WaitGroup
   698  	server, listen, serverErr := runServer(&wg)
   699  	if serverErr != nil {
   700  		t.Fatalf("Error initializing server: %s", serverErr)
   701  	}
   702  
   703  	conn, dialErr := net.Dial("tcp", listen)
   704  	if dialErr != nil {
   705  		t.Fatalf("Error dialing server: %s", dialErr)
   706  	}
   707  
   708  	spdyConn, spdyErr := NewConnection(conn, false)
   709  	if spdyErr != nil {
   710  		t.Fatalf("Error creating spdy connection: %s", spdyErr)
   711  	}
   712  	go spdyConn.Serve(NoOpStreamHandler)
   713  
   714  	authenticated = true
   715  	stream, streamErr := spdyConn.CreateStream(http.Header{}, nil, false)
   716  	if streamErr != nil {
   717  		t.Fatalf("Error creating stream: %s", streamErr)
   718  	}
   719  
   720  	buf := []byte("dskjahfkdusahfkdsahfkdsafdkas")
   721  	for i := 0; i < 10; i++ {
   722  		if _, err := stream.Write(buf); err != nil {
   723  			t.Fatalf("Error writing to stream: %s", err)
   724  		}
   725  	}
   726  	for i := 0; i < 10; i++ {
   727  		if _, err := stream.Read(buf); err != nil {
   728  			t.Fatalf("Error reading from stream: %s", err)
   729  		}
   730  	}
   731  
   732  	// fmt.Printf("Resetting...\n")
   733  	if err := stream.Reset(); err != nil {
   734  		t.Fatalf("Error reseting stream: %s", err)
   735  	}
   736  
   737  	closeErr := server.Close()
   738  	if closeErr != nil {
   739  		t.Fatalf("Error shutting down server: %s", closeErr)
   740  	}
   741  	wg.Wait()
   742  }
   743  
   744  func TestStreamResetWithDataRemaining(t *testing.T) {
   745  	var wg sync.WaitGroup
   746  	server, listen, serverErr := runServer(&wg)
   747  	if serverErr != nil {
   748  		t.Fatalf("Error initializing server: %s", serverErr)
   749  	}
   750  
   751  	conn, dialErr := net.Dial("tcp", listen)
   752  	if dialErr != nil {
   753  		t.Fatalf("Error dialing server: %s", dialErr)
   754  	}
   755  
   756  	spdyConn, spdyErr := NewConnection(conn, false)
   757  	if spdyErr != nil {
   758  		t.Fatalf("Error creating spdy connection: %s", spdyErr)
   759  	}
   760  	go spdyConn.Serve(NoOpStreamHandler)
   761  
   762  	authenticated = true
   763  	stream, streamErr := spdyConn.CreateStream(http.Header{}, nil, false)
   764  	if streamErr != nil {
   765  		t.Fatalf("Error creating stream: %s", streamErr)
   766  	}
   767  
   768  	buf := []byte("dskjahfkdusahfkdsahfkdsafdkas")
   769  	for i := 0; i < 10; i++ {
   770  		if _, err := stream.Write(buf); err != nil {
   771  			t.Fatalf("Error writing to stream: %s", err)
   772  		}
   773  	}
   774  
   775  	// read a bit to make sure a goroutine gets to <-dataChan
   776  	if _, err := stream.Read(buf); err != nil {
   777  		t.Fatalf("Error reading from stream: %s", err)
   778  	}
   779  
   780  	// fmt.Printf("Resetting...\n")
   781  	if err := stream.Reset(); err != nil {
   782  		t.Fatalf("Error reseting stream: %s", err)
   783  	}
   784  
   785  	closeErr := server.Close()
   786  	if closeErr != nil {
   787  		t.Fatalf("Error shutting down server: %s", closeErr)
   788  	}
   789  	wg.Wait()
   790  }
   791  
   792  type roundTripper struct {
   793  	conn net.Conn
   794  }
   795  
   796  func (s *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
   797  	r := *req
   798  	req = &r
   799  
   800  	conn, err := net.Dial("tcp", req.URL.Host)
   801  	if err != nil {
   802  		return nil, err
   803  	}
   804  
   805  	err = req.Write(conn)
   806  	if err != nil {
   807  		return nil, err
   808  	}
   809  
   810  	resp, err := http.ReadResponse(bufio.NewReader(conn), req)
   811  	if err != nil {
   812  		return nil, err
   813  	}
   814  
   815  	s.conn = conn
   816  
   817  	return resp, nil
   818  }
   819  
   820  // see https://github.com/GoogleCloudPlatform/kubernetes/issues/4882
   821  func TestFramingAfterRemoteConnectionClosed(t *testing.T) {
   822  	server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   823  		streamCh := make(chan *Stream)
   824  
   825  		w.WriteHeader(http.StatusSwitchingProtocols)
   826  
   827  		netconn, _, _ := w.(http.Hijacker).Hijack()
   828  		conn, _ := NewConnection(netconn, true)
   829  		go conn.Serve(func(s *Stream) {
   830  			s.SendReply(http.Header{}, false)
   831  			streamCh <- s
   832  		})
   833  
   834  		stream := <-streamCh
   835  		io.Copy(stream, stream)
   836  
   837  		closeChan := make(chan struct{})
   838  		go func() {
   839  			stream.Reset()
   840  			conn.Close()
   841  			close(closeChan)
   842  		}()
   843  
   844  		<-closeChan
   845  	}))
   846  
   847  	server.Start()
   848  	defer server.Close()
   849  
   850  	req, err := http.NewRequest("GET", server.URL, nil)
   851  	if err != nil {
   852  		t.Fatalf("Error creating request: %s", err)
   853  	}
   854  
   855  	rt := &roundTripper{}
   856  	client := &http.Client{Transport: rt}
   857  
   858  	_, err = client.Do(req)
   859  	if err != nil {
   860  		t.Fatalf("unexpected error from client.Do: %s", err)
   861  	}
   862  
   863  	conn, err := NewConnection(rt.conn, false)
   864  	if err != nil {
   865  		t.Fatal("Error creating spdy connection:", err)
   866  	}
   867  	go conn.Serve(NoOpStreamHandler)
   868  
   869  	stream, err := conn.CreateStream(http.Header{}, nil, false)
   870  	if err != nil {
   871  		t.Fatalf("error creating client stream: %s", err)
   872  	}
   873  
   874  	n, err := stream.Write([]byte("hello"))
   875  	if err != nil {
   876  		t.Fatalf("error writing to stream: %s", err)
   877  	}
   878  	if n != 5 {
   879  		t.Fatalf("Expected to write 5 bytes, but actually wrote %d", n)
   880  	}
   881  
   882  	b := make([]byte, 5)
   883  	n, err = stream.Read(b)
   884  	if err != nil {
   885  		t.Fatalf("error reading from stream: %s", err)
   886  	}
   887  	if n != 5 {
   888  		t.Fatalf("Expected to read 5 bytes, but actually read %d", n)
   889  	}
   890  	if e, a := "hello", string(b[0:n]); e != a {
   891  		t.Fatalf("expected '%s', got '%s'", e, a)
   892  	}
   893  
   894  	stream.Reset()
   895  	conn.Close()
   896  }
   897  
   898  func TestGoAwayRace(t *testing.T) {
   899  	var done sync.WaitGroup
   900  	listener, err := net.Listen("tcp", "localhost:0")
   901  	if err != nil {
   902  		t.Fatalf("Error listening: %v", err)
   903  	}
   904  	listen := listener.Addr().String()
   905  
   906  	processDataFrame := make(chan struct{})
   907  	serverClosed := make(chan struct{})
   908  
   909  	done.Add(1)
   910  	go func() {
   911  		defer done.Done()
   912  		serverConn, err := listener.Accept()
   913  		if err != nil {
   914  			t.Errorf("Error accepting: %v", err)
   915  		}
   916  
   917  		serverSpdyConn, err := NewConnection(serverConn, true)
   918  		if err != nil {
   919  			t.Errorf("Error creating server connection: %v", err)
   920  		}
   921  		go func() {
   922  			<-serverSpdyConn.CloseChan()
   923  			close(serverClosed)
   924  		}()
   925  
   926  		// force the data frame handler to sleep before delivering the frame
   927  		serverSpdyConn.dataFrameHandler = func(frame *spdy.DataFrame) error {
   928  			<-processDataFrame
   929  			return serverSpdyConn.handleDataFrame(frame)
   930  		}
   931  
   932  		streamCh := make(chan *Stream)
   933  		go serverSpdyConn.Serve(func(s *Stream) {
   934  			s.SendReply(http.Header{}, false)
   935  			streamCh <- s
   936  		})
   937  
   938  		stream, ok := <-streamCh
   939  		if !ok {
   940  			t.Errorf("didn't get a stream")
   941  		}
   942  		stream.Close()
   943  		data, err := ioutil.ReadAll(stream)
   944  		if err != nil {
   945  			t.Error(err)
   946  		}
   947  		if e, a := "hello1hello2hello3hello4hello5", string(data); e != a {
   948  			t.Errorf("Expected %q, got %q", e, a)
   949  		}
   950  	}()
   951  
   952  	dialConn, err := net.Dial("tcp", listen)
   953  	if err != nil {
   954  		t.Fatalf("Error dialing server: %s", err)
   955  	}
   956  	conn, err := NewConnection(dialConn, false)
   957  	if err != nil {
   958  		t.Fatalf("Error creating client connectin: %v", err)
   959  	}
   960  	go conn.Serve(NoOpStreamHandler)
   961  
   962  	stream, err := conn.CreateStream(http.Header{}, nil, false)
   963  	if err != nil {
   964  		t.Fatalf("error creating client stream: %s", err)
   965  	}
   966  	if err := stream.Wait(); err != nil {
   967  		t.Fatalf("error waiting for stream creation: %v", err)
   968  	}
   969  
   970  	fmt.Fprint(stream, "hello1")
   971  	fmt.Fprint(stream, "hello2")
   972  	fmt.Fprint(stream, "hello3")
   973  	fmt.Fprint(stream, "hello4")
   974  	fmt.Fprint(stream, "hello5")
   975  
   976  	stream.Close()
   977  	conn.Close()
   978  
   979  	// wait for the server to get the go away frame
   980  	<-serverClosed
   981  
   982  	// allow the data frames to be delivered to the server's stream
   983  	close(processDataFrame)
   984  
   985  	done.Wait()
   986  }
   987  
   988  func TestSetIdleTimeoutAfterRemoteConnectionClosed(t *testing.T) {
   989  	listener, err := net.Listen("tcp", "localhost:0")
   990  	if err != nil {
   991  		t.Fatalf("Error listening: %v", err)
   992  	}
   993  	listen := listener.Addr().String()
   994  
   995  	serverConns := make(chan *Connection, 1)
   996  	go func() {
   997  		conn, connErr := listener.Accept()
   998  		if connErr != nil {
   999  			t.Error(connErr)
  1000  		}
  1001  		serverSpdyConn, err := NewConnection(conn, true)
  1002  		if err != nil {
  1003  			t.Errorf("Error creating server connection: %v", err)
  1004  		}
  1005  		go serverSpdyConn.Serve(NoOpStreamHandler)
  1006  		serverConns <- serverSpdyConn
  1007  	}()
  1008  
  1009  	conn, dialErr := net.Dial("tcp", listen)
  1010  	if dialErr != nil {
  1011  		t.Fatalf("Error dialing server: %s", dialErr)
  1012  	}
  1013  
  1014  	spdyConn, spdyErr := NewConnection(conn, false)
  1015  	if spdyErr != nil {
  1016  		t.Fatalf("Error creating spdy connection: %s", spdyErr)
  1017  	}
  1018  	go spdyConn.Serve(NoOpStreamHandler)
  1019  
  1020  	if err := spdyConn.Close(); err != nil {
  1021  		t.Fatal(err)
  1022  	}
  1023  
  1024  	serverConn := <-serverConns
  1025  	defer serverConn.Close()
  1026  	<-serverConn.closeChan
  1027  
  1028  	serverConn.SetIdleTimeout(10 * time.Second)
  1029  }
  1030  
  1031  func TestClientConnectionStopsServingAfterGoAway(t *testing.T) {
  1032  	listener, err := net.Listen("tcp", "localhost:0")
  1033  	if err != nil {
  1034  		t.Fatalf("Error listening: %v", err)
  1035  	}
  1036  	listen := listener.Addr().String()
  1037  
  1038  	serverConns := make(chan *Connection, 1)
  1039  	go func() {
  1040  		conn, connErr := listener.Accept()
  1041  		if connErr != nil {
  1042  			t.Error(connErr)
  1043  		}
  1044  		serverSpdyConn, err := NewConnection(conn, true)
  1045  		if err != nil {
  1046  			t.Errorf("Error creating server connection: %v", err)
  1047  		}
  1048  		go serverSpdyConn.Serve(NoOpStreamHandler)
  1049  		serverConns <- serverSpdyConn
  1050  	}()
  1051  
  1052  	conn, dialErr := net.Dial("tcp", listen)
  1053  	if dialErr != nil {
  1054  		t.Fatalf("Error dialing server: %s", dialErr)
  1055  	}
  1056  
  1057  	spdyConn, spdyErr := NewConnection(conn, false)
  1058  	if spdyErr != nil {
  1059  		t.Fatalf("Error creating spdy connection: %s", spdyErr)
  1060  	}
  1061  	go spdyConn.Serve(NoOpStreamHandler)
  1062  
  1063  	stream, err := spdyConn.CreateStream(http.Header{}, nil, false)
  1064  	if err != nil {
  1065  		t.Fatalf("Error creating stream: %v", err)
  1066  	}
  1067  	if err := stream.WaitTimeout(30 * time.Second); err != nil {
  1068  		t.Fatalf("Timed out waiting for stream: %v", err)
  1069  	}
  1070  
  1071  	readChan := make(chan struct{})
  1072  	go func() {
  1073  		_, err := ioutil.ReadAll(stream)
  1074  		if err != nil {
  1075  			t.Errorf("Error reading stream: %v", err)
  1076  		}
  1077  		close(readChan)
  1078  	}()
  1079  
  1080  	serverConn := <-serverConns
  1081  	serverConn.Close()
  1082  
  1083  	// make sure the client conn breaks out of the main loop in Serve()
  1084  	<-spdyConn.closeChan
  1085  	// make sure the remote channels are closed and the stream read is unblocked
  1086  	<-readChan
  1087  }
  1088  
  1089  func TestStreamReadUnblocksAfterCloseThenReset(t *testing.T) {
  1090  	listener, err := net.Listen("tcp", "localhost:0")
  1091  	if err != nil {
  1092  		t.Fatalf("Error listening: %v", err)
  1093  	}
  1094  	listen := listener.Addr().String()
  1095  
  1096  	serverConns := make(chan *Connection, 1)
  1097  	go func() {
  1098  		conn, connErr := listener.Accept()
  1099  		if connErr != nil {
  1100  			t.Error(connErr)
  1101  		}
  1102  		serverSpdyConn, err := NewConnection(conn, true)
  1103  		if err != nil {
  1104  			t.Errorf("Error creating server connection: %v", err)
  1105  		}
  1106  		go serverSpdyConn.Serve(NoOpStreamHandler)
  1107  		serverConns <- serverSpdyConn
  1108  	}()
  1109  
  1110  	conn, dialErr := net.Dial("tcp", listen)
  1111  	if dialErr != nil {
  1112  		t.Fatalf("Error dialing server: %s", dialErr)
  1113  	}
  1114  
  1115  	spdyConn, spdyErr := NewConnection(conn, false)
  1116  	if spdyErr != nil {
  1117  		t.Fatalf("Error creating spdy connection: %s", spdyErr)
  1118  	}
  1119  	go spdyConn.Serve(NoOpStreamHandler)
  1120  
  1121  	stream, err := spdyConn.CreateStream(http.Header{}, nil, false)
  1122  	if err != nil {
  1123  		t.Fatalf("Error creating stream: %v", err)
  1124  	}
  1125  	if err := stream.WaitTimeout(30 * time.Second); err != nil {
  1126  		t.Fatalf("Timed out waiting for stream: %v", err)
  1127  	}
  1128  
  1129  	readChan := make(chan struct{})
  1130  	go func() {
  1131  		_, err := ioutil.ReadAll(stream)
  1132  		if err != nil {
  1133  			t.Errorf("Error reading stream: %v", err)
  1134  		}
  1135  		close(readChan)
  1136  	}()
  1137  
  1138  	serverConn := <-serverConns
  1139  	defer serverConn.Close()
  1140  
  1141  	if err := stream.Close(); err != nil {
  1142  		t.Fatal(err)
  1143  	}
  1144  	if err := stream.Reset(); err != nil {
  1145  		t.Fatal(err)
  1146  	}
  1147  
  1148  	// make sure close followed by reset unblocks stream.Read()
  1149  	select {
  1150  	case <-readChan:
  1151  	case <-time.After(10 * time.Second):
  1152  		t.Fatal("Timed out waiting for stream read to unblock")
  1153  	}
  1154  }
  1155  
  1156  var authenticated bool
  1157  
  1158  func authStreamHandler(stream *Stream) {
  1159  	if !authenticated {
  1160  		stream.Refuse()
  1161  		return
  1162  	}
  1163  	MirrorStreamHandler(stream)
  1164  }
  1165  
  1166  func runServer(wg *sync.WaitGroup) (io.Closer, string, error) {
  1167  	listener, listenErr := net.Listen("tcp", "localhost:0")
  1168  	if listenErr != nil {
  1169  		return nil, "", listenErr
  1170  	}
  1171  	wg.Add(1)
  1172  	go func() {
  1173  		for {
  1174  			conn, connErr := listener.Accept()
  1175  			if connErr != nil {
  1176  				break
  1177  			}
  1178  
  1179  			spdyConn, _ := NewConnection(conn, true)
  1180  			go spdyConn.Serve(authStreamHandler)
  1181  
  1182  		}
  1183  		wg.Done()
  1184  	}()
  1185  	return listener, listener.Addr().String(), nil
  1186  }
  1187  

View as plain text