...

Source file src/golang.org/x/net/http2/clientconn_test.go

Documentation: golang.org/x/net/http2

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Infrastructure for testing ClientConn.RoundTrip.
     6  // Put actual tests in transport_test.go.
     7  
     8  package http2
     9  
    10  import (
    11  	"bytes"
    12  	"context"
    13  	"crypto/tls"
    14  	"fmt"
    15  	"io"
    16  	"net/http"
    17  	"reflect"
    18  	"sync/atomic"
    19  	"testing"
    20  	"time"
    21  
    22  	"golang.org/x/net/http2/hpack"
    23  )
    24  
    25  // TestTestClientConn demonstrates usage of testClientConn.
    26  func TestTestClientConn(t *testing.T) {
    27  	// newTestClientConn creates a *ClientConn and surrounding test infrastructure.
    28  	tc := newTestClientConn(t)
    29  
    30  	// tc.greet reads the client's initial SETTINGS and WINDOW_UPDATE frames,
    31  	// and sends a SETTINGS frame to the client.
    32  	//
    33  	// Additional settings may be provided as optional parameters to greet.
    34  	tc.greet()
    35  
    36  	// Request bodies must either be constant (bytes.Buffer, strings.Reader)
    37  	// or created with newRequestBody.
    38  	body := tc.newRequestBody()
    39  	body.writeBytes(10)         // 10 arbitrary bytes...
    40  	body.closeWithError(io.EOF) // ...followed by EOF.
    41  
    42  	// tc.roundTrip calls RoundTrip, but does not wait for it to return.
    43  	// It returns a testRoundTrip.
    44  	req, _ := http.NewRequest("PUT", "https://dummy.tld/", body)
    45  	rt := tc.roundTrip(req)
    46  
    47  	// tc has a number of methods to check for expected frames sent.
    48  	// Here, we look for headers and the request body.
    49  	tc.wantHeaders(wantHeader{
    50  		streamID:  rt.streamID(),
    51  		endStream: false,
    52  		header: http.Header{
    53  			":authority": []string{"dummy.tld"},
    54  			":method":    []string{"PUT"},
    55  			":path":      []string{"/"},
    56  		},
    57  	})
    58  	// Expect 10 bytes of request body in DATA frames.
    59  	tc.wantData(wantData{
    60  		streamID:  rt.streamID(),
    61  		endStream: true,
    62  		size:      10,
    63  		multiple:  true,
    64  	})
    65  
    66  	// tc.writeHeaders sends a HEADERS frame back to the client.
    67  	tc.writeHeaders(HeadersFrameParam{
    68  		StreamID:   rt.streamID(),
    69  		EndHeaders: true,
    70  		EndStream:  true,
    71  		BlockFragment: tc.makeHeaderBlockFragment(
    72  			":status", "200",
    73  		),
    74  	})
    75  
    76  	// Now that we've received headers, RoundTrip has finished.
    77  	// testRoundTrip has various methods to examine the response,
    78  	// or to fetch the response and/or error returned by RoundTrip
    79  	rt.wantStatus(200)
    80  	rt.wantBody(nil)
    81  }
    82  
    83  // A testClientConn allows testing ClientConn.RoundTrip against a fake server.
    84  //
    85  // A test using testClientConn consists of:
    86  //   - actions on the client (calling RoundTrip, making data available to Request.Body);
    87  //   - validation of frames sent by the client to the server; and
    88  //   - providing frames from the server to the client.
    89  //
    90  // testClientConn manages synchronization, so tests can generally be written as
    91  // a linear sequence of actions and validations without additional synchronization.
    92  type testClientConn struct {
    93  	t *testing.T
    94  
    95  	tr    *Transport
    96  	fr    *Framer
    97  	cc    *ClientConn
    98  	group *synctestGroup
    99  	testConnFramer
   100  
   101  	encbuf bytes.Buffer
   102  	enc    *hpack.Encoder
   103  
   104  	roundtrips []*testRoundTrip
   105  
   106  	netconn *synctestNetConn
   107  }
   108  
   109  func newTestClientConnFromClientConn(t *testing.T, cc *ClientConn) *testClientConn {
   110  	tc := &testClientConn{
   111  		t:     t,
   112  		tr:    cc.t,
   113  		cc:    cc,
   114  		group: cc.t.transportTestHooks.group.(*synctestGroup),
   115  	}
   116  
   117  	// srv is the side controlled by the test.
   118  	var srv *synctestNetConn
   119  	if cc.tconn == nil {
   120  		// If cc.tconn is nil, we're being called with a new conn created by the
   121  		// Transport's client pool. This path skips dialing the server, and we
   122  		// create a test connection pair here.
   123  		cc.tconn, srv = synctestNetPipe(tc.group)
   124  	} else {
   125  		// If cc.tconn is non-nil, we're in a test which provides a conn to the
   126  		// Transport via a TLSNextProto hook. Extract the test connection pair.
   127  		if tc, ok := cc.tconn.(*tls.Conn); ok {
   128  			// Unwrap any *tls.Conn to the underlying net.Conn,
   129  			// to avoid dealing with encryption in tests.
   130  			cc.tconn = tc.NetConn()
   131  		}
   132  		srv = cc.tconn.(*synctestNetConn).peer
   133  	}
   134  
   135  	srv.SetReadDeadline(tc.group.Now())
   136  	srv.autoWait = true
   137  	tc.netconn = srv
   138  	tc.enc = hpack.NewEncoder(&tc.encbuf)
   139  	tc.fr = NewFramer(srv, srv)
   140  	tc.testConnFramer = testConnFramer{
   141  		t:   t,
   142  		fr:  tc.fr,
   143  		dec: hpack.NewDecoder(initialHeaderTableSize, nil),
   144  	}
   145  	tc.fr.SetMaxReadFrameSize(10 << 20)
   146  	t.Cleanup(func() {
   147  		tc.closeWrite()
   148  	})
   149  
   150  	return tc
   151  }
   152  
   153  func (tc *testClientConn) readClientPreface() {
   154  	tc.t.Helper()
   155  	// Read the client's HTTP/2 preface, sent prior to any HTTP/2 frames.
   156  	buf := make([]byte, len(clientPreface))
   157  	if _, err := io.ReadFull(tc.netconn, buf); err != nil {
   158  		tc.t.Fatalf("reading preface: %v", err)
   159  	}
   160  	if !bytes.Equal(buf, clientPreface) {
   161  		tc.t.Fatalf("client preface: %q, want %q", buf, clientPreface)
   162  	}
   163  }
   164  
   165  func newTestClientConn(t *testing.T, opts ...any) *testClientConn {
   166  	t.Helper()
   167  
   168  	tt := newTestTransport(t, opts...)
   169  	const singleUse = false
   170  	_, err := tt.tr.newClientConn(nil, singleUse)
   171  	if err != nil {
   172  		t.Fatalf("newClientConn: %v", err)
   173  	}
   174  
   175  	return tt.getConn()
   176  }
   177  
   178  // sync waits for the ClientConn under test to reach a stable state,
   179  // with all goroutines blocked on some input.
   180  func (tc *testClientConn) sync() {
   181  	tc.group.Wait()
   182  }
   183  
   184  // advance advances synthetic time by a duration.
   185  func (tc *testClientConn) advance(d time.Duration) {
   186  	tc.group.AdvanceTime(d)
   187  	tc.sync()
   188  }
   189  
   190  // hasFrame reports whether a frame is available to be read.
   191  func (tc *testClientConn) hasFrame() bool {
   192  	return len(tc.netconn.Peek()) > 0
   193  }
   194  
   195  // isClosed reports whether the peer has closed the connection.
   196  func (tc *testClientConn) isClosed() bool {
   197  	return tc.netconn.IsClosedByPeer()
   198  }
   199  
   200  // closeWrite causes the net.Conn used by the ClientConn to return a error
   201  // from Read calls.
   202  func (tc *testClientConn) closeWrite() {
   203  	tc.netconn.Close()
   204  }
   205  
   206  // testRequestBody is a Request.Body for use in tests.
   207  type testRequestBody struct {
   208  	tc   *testClientConn
   209  	gate gate
   210  
   211  	// At most one of buf or bytes can be set at any given time:
   212  	buf   bytes.Buffer // specific bytes to read from the body
   213  	bytes int          // body contains this many arbitrary bytes
   214  
   215  	err error // read error (comes after any available bytes)
   216  }
   217  
   218  func (tc *testClientConn) newRequestBody() *testRequestBody {
   219  	b := &testRequestBody{
   220  		tc:   tc,
   221  		gate: newGate(),
   222  	}
   223  	return b
   224  }
   225  
   226  func (b *testRequestBody) unlock() {
   227  	b.gate.unlock(b.buf.Len() > 0 || b.bytes > 0 || b.err != nil)
   228  }
   229  
   230  // Read is called by the ClientConn to read from a request body.
   231  func (b *testRequestBody) Read(p []byte) (n int, _ error) {
   232  	if err := b.gate.waitAndLock(context.Background()); err != nil {
   233  		return 0, err
   234  	}
   235  	defer b.unlock()
   236  	switch {
   237  	case b.buf.Len() > 0:
   238  		return b.buf.Read(p)
   239  	case b.bytes > 0:
   240  		if len(p) > b.bytes {
   241  			p = p[:b.bytes]
   242  		}
   243  		b.bytes -= len(p)
   244  		for i := range p {
   245  			p[i] = 'A'
   246  		}
   247  		return len(p), nil
   248  	default:
   249  		return 0, b.err
   250  	}
   251  }
   252  
   253  // Close is called by the ClientConn when it is done reading from a request body.
   254  func (b *testRequestBody) Close() error {
   255  	return nil
   256  }
   257  
   258  // writeBytes adds n arbitrary bytes to the body.
   259  func (b *testRequestBody) writeBytes(n int) {
   260  	defer b.tc.sync()
   261  	b.gate.lock()
   262  	defer b.unlock()
   263  	b.bytes += n
   264  	b.checkWrite()
   265  	b.tc.sync()
   266  }
   267  
   268  // Write adds bytes to the body.
   269  func (b *testRequestBody) Write(p []byte) (int, error) {
   270  	defer b.tc.sync()
   271  	b.gate.lock()
   272  	defer b.unlock()
   273  	n, err := b.buf.Write(p)
   274  	b.checkWrite()
   275  	return n, err
   276  }
   277  
   278  func (b *testRequestBody) checkWrite() {
   279  	if b.bytes > 0 && b.buf.Len() > 0 {
   280  		b.tc.t.Fatalf("can't interleave Write and writeBytes on request body")
   281  	}
   282  	if b.err != nil {
   283  		b.tc.t.Fatalf("can't write to request body after closeWithError")
   284  	}
   285  }
   286  
   287  // closeWithError sets an error which will be returned by Read.
   288  func (b *testRequestBody) closeWithError(err error) {
   289  	defer b.tc.sync()
   290  	b.gate.lock()
   291  	defer b.unlock()
   292  	b.err = err
   293  }
   294  
   295  // roundTrip starts a RoundTrip call.
   296  //
   297  // (Note that the RoundTrip won't complete until response headers are received,
   298  // the request times out, or some other terminal condition is reached.)
   299  func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip {
   300  	rt := &testRoundTrip{
   301  		t:     tc.t,
   302  		donec: make(chan struct{}),
   303  	}
   304  	tc.roundtrips = append(tc.roundtrips, rt)
   305  	go func() {
   306  		tc.group.Join()
   307  		defer close(rt.donec)
   308  		rt.resp, rt.respErr = tc.cc.roundTrip(req, func(cs *clientStream) {
   309  			rt.id.Store(cs.ID)
   310  		})
   311  	}()
   312  	tc.sync()
   313  
   314  	tc.t.Cleanup(func() {
   315  		if !rt.done() {
   316  			return
   317  		}
   318  		res, _ := rt.result()
   319  		if res != nil {
   320  			res.Body.Close()
   321  		}
   322  	})
   323  
   324  	return rt
   325  }
   326  
   327  func (tc *testClientConn) greet(settings ...Setting) {
   328  	tc.wantFrameType(FrameSettings)
   329  	tc.wantFrameType(FrameWindowUpdate)
   330  	tc.writeSettings(settings...)
   331  	tc.writeSettingsAck()
   332  	tc.wantFrameType(FrameSettings) // acknowledgement
   333  }
   334  
   335  // makeHeaderBlockFragment encodes headers in a form suitable for inclusion
   336  // in a HEADERS or CONTINUATION frame.
   337  //
   338  // It takes a list of alernating names and values.
   339  func (tc *testClientConn) makeHeaderBlockFragment(s ...string) []byte {
   340  	if len(s)%2 != 0 {
   341  		tc.t.Fatalf("uneven list of header name/value pairs")
   342  	}
   343  	tc.encbuf.Reset()
   344  	for i := 0; i < len(s); i += 2 {
   345  		tc.enc.WriteField(hpack.HeaderField{Name: s[i], Value: s[i+1]})
   346  	}
   347  	return tc.encbuf.Bytes()
   348  }
   349  
   350  // inflowWindow returns the amount of inbound flow control available for a stream,
   351  // or for the connection if streamID is 0.
   352  func (tc *testClientConn) inflowWindow(streamID uint32) int32 {
   353  	tc.cc.mu.Lock()
   354  	defer tc.cc.mu.Unlock()
   355  	if streamID == 0 {
   356  		return tc.cc.inflow.avail + tc.cc.inflow.unsent
   357  	}
   358  	cs := tc.cc.streams[streamID]
   359  	if cs == nil {
   360  		tc.t.Errorf("no stream with id %v", streamID)
   361  		return -1
   362  	}
   363  	return cs.inflow.avail + cs.inflow.unsent
   364  }
   365  
   366  // testRoundTrip manages a RoundTrip in progress.
   367  type testRoundTrip struct {
   368  	t       *testing.T
   369  	resp    *http.Response
   370  	respErr error
   371  	donec   chan struct{}
   372  	id      atomic.Uint32
   373  }
   374  
   375  // streamID returns the HTTP/2 stream ID of the request.
   376  func (rt *testRoundTrip) streamID() uint32 {
   377  	id := rt.id.Load()
   378  	if id == 0 {
   379  		panic("stream ID unknown")
   380  	}
   381  	return id
   382  }
   383  
   384  // done reports whether RoundTrip has returned.
   385  func (rt *testRoundTrip) done() bool {
   386  	select {
   387  	case <-rt.donec:
   388  		return true
   389  	default:
   390  		return false
   391  	}
   392  }
   393  
   394  // result returns the result of the RoundTrip.
   395  func (rt *testRoundTrip) result() (*http.Response, error) {
   396  	t := rt.t
   397  	t.Helper()
   398  	select {
   399  	case <-rt.donec:
   400  	default:
   401  		t.Fatalf("RoundTrip is not done; want it to be")
   402  	}
   403  	return rt.resp, rt.respErr
   404  }
   405  
   406  // response returns the response of a successful RoundTrip.
   407  // If the RoundTrip unexpectedly failed, it calls t.Fatal.
   408  func (rt *testRoundTrip) response() *http.Response {
   409  	t := rt.t
   410  	t.Helper()
   411  	resp, err := rt.result()
   412  	if err != nil {
   413  		t.Fatalf("RoundTrip returned unexpected error: %v", rt.respErr)
   414  	}
   415  	if resp == nil {
   416  		t.Fatalf("RoundTrip returned nil *Response and nil error")
   417  	}
   418  	return resp
   419  }
   420  
   421  // err returns the (possibly nil) error result of RoundTrip.
   422  func (rt *testRoundTrip) err() error {
   423  	t := rt.t
   424  	t.Helper()
   425  	_, err := rt.result()
   426  	return err
   427  }
   428  
   429  // wantStatus indicates the expected response StatusCode.
   430  func (rt *testRoundTrip) wantStatus(want int) {
   431  	t := rt.t
   432  	t.Helper()
   433  	if got := rt.response().StatusCode; got != want {
   434  		t.Fatalf("got response status %v, want %v", got, want)
   435  	}
   436  }
   437  
   438  // body reads the contents of the response body.
   439  func (rt *testRoundTrip) readBody() ([]byte, error) {
   440  	t := rt.t
   441  	t.Helper()
   442  	return io.ReadAll(rt.response().Body)
   443  }
   444  
   445  // wantBody indicates the expected response body.
   446  // (Note that this consumes the body.)
   447  func (rt *testRoundTrip) wantBody(want []byte) {
   448  	t := rt.t
   449  	t.Helper()
   450  	got, err := rt.readBody()
   451  	if err != nil {
   452  		t.Fatalf("unexpected error reading response body: %v", err)
   453  	}
   454  	if !bytes.Equal(got, want) {
   455  		t.Fatalf("unexpected response body:\ngot:  %q\nwant: %q", got, want)
   456  	}
   457  }
   458  
   459  // wantHeaders indicates the expected response headers.
   460  func (rt *testRoundTrip) wantHeaders(want http.Header) {
   461  	t := rt.t
   462  	t.Helper()
   463  	res := rt.response()
   464  	if diff := diffHeaders(res.Header, want); diff != "" {
   465  		t.Fatalf("unexpected response headers:\n%v", diff)
   466  	}
   467  }
   468  
   469  // wantTrailers indicates the expected response trailers.
   470  func (rt *testRoundTrip) wantTrailers(want http.Header) {
   471  	t := rt.t
   472  	t.Helper()
   473  	res := rt.response()
   474  	if diff := diffHeaders(res.Trailer, want); diff != "" {
   475  		t.Fatalf("unexpected response trailers:\n%v", diff)
   476  	}
   477  }
   478  
   479  func diffHeaders(got, want http.Header) string {
   480  	// nil and 0-length non-nil are equal.
   481  	if len(got) == 0 && len(want) == 0 {
   482  		return ""
   483  	}
   484  	// We could do a more sophisticated diff here.
   485  	// DeepEqual is good enough for now.
   486  	if reflect.DeepEqual(got, want) {
   487  		return ""
   488  	}
   489  	return fmt.Sprintf("got:  %v\nwant: %v", got, want)
   490  }
   491  
   492  // A testTransport allows testing Transport.RoundTrip against fake servers.
   493  // Tests that aren't specifically exercising RoundTrip's retry loop or connection pooling
   494  // should use testClientConn instead.
   495  type testTransport struct {
   496  	t     *testing.T
   497  	tr    *Transport
   498  	group *synctestGroup
   499  
   500  	ccs []*testClientConn
   501  }
   502  
   503  func newTestTransport(t *testing.T, opts ...any) *testTransport {
   504  	tt := &testTransport{
   505  		t:     t,
   506  		group: newSynctest(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)),
   507  	}
   508  	tt.group.Join()
   509  
   510  	tr := &Transport{}
   511  	for _, o := range opts {
   512  		switch o := o.(type) {
   513  		case func(*http.Transport):
   514  			if tr.t1 == nil {
   515  				tr.t1 = &http.Transport{}
   516  			}
   517  			o(tr.t1)
   518  		case func(*Transport):
   519  			o(tr)
   520  		case *Transport:
   521  			tr = o
   522  		}
   523  	}
   524  	tt.tr = tr
   525  
   526  	tr.transportTestHooks = &transportTestHooks{
   527  		group: tt.group,
   528  		newclientconn: func(cc *ClientConn) {
   529  			tc := newTestClientConnFromClientConn(t, cc)
   530  			tt.ccs = append(tt.ccs, tc)
   531  		},
   532  	}
   533  
   534  	t.Cleanup(func() {
   535  		tt.sync()
   536  		if len(tt.ccs) > 0 {
   537  			t.Fatalf("%v test ClientConns created, but not examined by test", len(tt.ccs))
   538  		}
   539  		tt.group.Close(t)
   540  	})
   541  
   542  	return tt
   543  }
   544  
   545  func (tt *testTransport) sync() {
   546  	tt.group.Wait()
   547  }
   548  
   549  func (tt *testTransport) advance(d time.Duration) {
   550  	tt.group.AdvanceTime(d)
   551  	tt.sync()
   552  }
   553  
   554  func (tt *testTransport) hasConn() bool {
   555  	return len(tt.ccs) > 0
   556  }
   557  
   558  func (tt *testTransport) getConn() *testClientConn {
   559  	tt.t.Helper()
   560  	if len(tt.ccs) == 0 {
   561  		tt.t.Fatalf("no new ClientConns created; wanted one")
   562  	}
   563  	tc := tt.ccs[0]
   564  	tt.ccs = tt.ccs[1:]
   565  	tc.sync()
   566  	tc.readClientPreface()
   567  	tc.sync()
   568  	return tc
   569  }
   570  
   571  func (tt *testTransport) roundTrip(req *http.Request) *testRoundTrip {
   572  	rt := &testRoundTrip{
   573  		t:     tt.t,
   574  		donec: make(chan struct{}),
   575  	}
   576  	go func() {
   577  		tt.group.Join()
   578  		defer close(rt.donec)
   579  		rt.resp, rt.respErr = tt.tr.RoundTrip(req)
   580  	}()
   581  	tt.sync()
   582  
   583  	tt.t.Cleanup(func() {
   584  		if !rt.done() {
   585  			return
   586  		}
   587  		res, _ := rt.result()
   588  		if res != nil {
   589  			res.Body.Close()
   590  		}
   591  	})
   592  
   593  	return rt
   594  }
   595  

View as plain text