...

Source file src/golang.org/x/net/http2/connframes_test.go

Documentation: golang.org/x/net/http2

     1  // Copyright 2024 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package http2
     6  
import (
	"bytes"
	"errors"
	"io"
	"net/http"
	"os"
	"reflect"
	"slices"
	"testing"

	"golang.org/x/net/http2/hpack"
)
    18  
    19  type testConnFramer struct {
    20  	t   testing.TB
    21  	fr  *Framer
    22  	dec *hpack.Decoder
    23  }
    24  
    25  // readFrame reads the next frame.
    26  // It returns nil if the conn is closed or no frames are available.
    27  func (tf *testConnFramer) readFrame() Frame {
    28  	tf.t.Helper()
    29  	fr, err := tf.fr.ReadFrame()
    30  	if err == io.EOF || err == os.ErrDeadlineExceeded {
    31  		return nil
    32  	}
    33  	if err != nil {
    34  		tf.t.Fatalf("ReadFrame: %v", err)
    35  	}
    36  	return fr
    37  }
    38  
    39  type readFramer interface {
    40  	readFrame() Frame
    41  }
    42  
    43  // readFrame reads a frame of a specific type.
    44  func readFrame[T any](t testing.TB, framer readFramer) T {
    45  	t.Helper()
    46  	var v T
    47  	fr := framer.readFrame()
    48  	if fr == nil {
    49  		t.Fatalf("got no frame, want frame %T", v)
    50  	}
    51  	v, ok := fr.(T)
    52  	if !ok {
    53  		t.Fatalf("got frame %T, want %T", fr, v)
    54  	}
    55  	return v
    56  }
    57  
    58  // wantFrameType reads the next frame.
    59  // It produces an error if the frame type is not the expected value.
    60  func (tf *testConnFramer) wantFrameType(want FrameType) {
    61  	tf.t.Helper()
    62  	fr := tf.readFrame()
    63  	if fr == nil {
    64  		tf.t.Fatalf("got no frame, want frame %v", want)
    65  	}
    66  	if got := fr.Header().Type; got != want {
    67  		tf.t.Fatalf("got frame %v, want %v", got, want)
    68  	}
    69  }
    70  
    71  // wantUnorderedFrames reads frames until every condition in want has been satisfied.
    72  //
    73  // want is a list of func(*SomeFrame) bool.
    74  // wantUnorderedFrames will call each func with frames of the appropriate type
    75  // until the func returns true.
    76  // It calls t.Fatal if an unexpected frame is received (no func has that frame type,
    77  // or all funcs with that type have returned true), or if the framer runs out of frames
    78  // with unsatisfied funcs.
    79  //
    80  // Example:
    81  //
    82  //	// Read a SETTINGS frame, and any number of DATA frames for a stream.
    83  //	// The SETTINGS frame may appear anywhere in the sequence.
    84  //	// The last DATA frame must indicate the end of the stream.
    85  //	tf.wantUnorderedFrames(
    86  //		func(f *SettingsFrame) bool {
    87  //			return true
    88  //		},
    89  //		func(f *DataFrame) bool {
    90  //			return f.StreamEnded()
    91  //		},
    92  //	)
    93  func (tf *testConnFramer) wantUnorderedFrames(want ...any) {
    94  	tf.t.Helper()
    95  	want = slices.Clone(want)
    96  	seen := 0
    97  frame:
    98  	for seen < len(want) && !tf.t.Failed() {
    99  		fr := tf.readFrame()
   100  		if fr == nil {
   101  			break
   102  		}
   103  		for i, f := range want {
   104  			if f == nil {
   105  				continue
   106  			}
   107  			typ := reflect.TypeOf(f)
   108  			if typ.Kind() != reflect.Func ||
   109  				typ.NumIn() != 1 ||
   110  				typ.NumOut() != 1 ||
   111  				typ.Out(0) != reflect.TypeOf(true) {
   112  				tf.t.Fatalf("expected func(*SomeFrame) bool, got %T", f)
   113  			}
   114  			if typ.In(0) == reflect.TypeOf(fr) {
   115  				out := reflect.ValueOf(f).Call([]reflect.Value{reflect.ValueOf(fr)})
   116  				if out[0].Bool() {
   117  					want[i] = nil
   118  					seen++
   119  				}
   120  				continue frame
   121  			}
   122  		}
   123  		tf.t.Errorf("got unexpected frame type %T", fr)
   124  	}
   125  	if seen < len(want) {
   126  		for _, f := range want {
   127  			if f == nil {
   128  				continue
   129  			}
   130  			tf.t.Errorf("did not see expected frame: %v", reflect.TypeOf(f).In(0))
   131  		}
   132  		tf.t.Fatalf("did not see %v expected frame types", len(want)-seen)
   133  	}
   134  }
   135  
   136  type wantHeader struct {
   137  	streamID  uint32
   138  	endStream bool
   139  	header    http.Header
   140  }
   141  
   142  // wantHeaders reads a HEADERS frame and potential CONTINUATION frames,
   143  // and asserts that they contain the expected headers.
   144  func (tf *testConnFramer) wantHeaders(want wantHeader) {
   145  	tf.t.Helper()
   146  
   147  	hf := readFrame[*HeadersFrame](tf.t, tf)
   148  	if got, want := hf.StreamID, want.streamID; got != want {
   149  		tf.t.Fatalf("got stream ID %v, want %v", got, want)
   150  	}
   151  	if got, want := hf.StreamEnded(), want.endStream; got != want {
   152  		tf.t.Fatalf("got stream ended %v, want %v", got, want)
   153  	}
   154  
   155  	gotHeader := make(http.Header)
   156  	tf.dec.SetEmitFunc(func(hf hpack.HeaderField) {
   157  		gotHeader[hf.Name] = append(gotHeader[hf.Name], hf.Value)
   158  	})
   159  	defer tf.dec.SetEmitFunc(nil)
   160  	if _, err := tf.dec.Write(hf.HeaderBlockFragment()); err != nil {
   161  		tf.t.Fatalf("decoding HEADERS frame: %v", err)
   162  	}
   163  	headersEnded := hf.HeadersEnded()
   164  	for !headersEnded {
   165  		cf := readFrame[*ContinuationFrame](tf.t, tf)
   166  		if cf == nil {
   167  			tf.t.Fatalf("got end of frames, want CONTINUATION")
   168  		}
   169  		if _, err := tf.dec.Write(cf.HeaderBlockFragment()); err != nil {
   170  			tf.t.Fatalf("decoding CONTINUATION frame: %v", err)
   171  		}
   172  		headersEnded = cf.HeadersEnded()
   173  	}
   174  	if err := tf.dec.Close(); err != nil {
   175  		tf.t.Fatalf("hpack decoding error: %v", err)
   176  	}
   177  
   178  	for k, v := range want.header {
   179  		if !reflect.DeepEqual(v, gotHeader[k]) {
   180  			tf.t.Fatalf("got header %q = %q; want %q", k, v, gotHeader[k])
   181  		}
   182  	}
   183  }
   184  
   185  // decodeHeader supports some older server tests.
   186  // TODO: rewrite those tests to use newer, more convenient test APIs.
   187  func (tf *testConnFramer) decodeHeader(headerBlock []byte) (pairs [][2]string) {
   188  	tf.dec.SetEmitFunc(func(hf hpack.HeaderField) {
   189  		if hf.Name == "date" {
   190  			return
   191  		}
   192  		pairs = append(pairs, [2]string{hf.Name, hf.Value})
   193  	})
   194  	defer tf.dec.SetEmitFunc(nil)
   195  	if _, err := tf.dec.Write(headerBlock); err != nil {
   196  		tf.t.Fatalf("hpack decoding error: %v", err)
   197  	}
   198  	if err := tf.dec.Close(); err != nil {
   199  		tf.t.Fatalf("hpack decoding error: %v", err)
   200  	}
   201  	return pairs
   202  }
   203  
   204  type wantData struct {
   205  	streamID  uint32
   206  	endStream bool
   207  	size      int
   208  	data      []byte
   209  	multiple  bool // data may be spread across multiple DATA frames
   210  }
   211  
   212  // wantData reads zero or more DATA frames, and asserts that they match the expectation.
   213  func (tf *testConnFramer) wantData(want wantData) {
   214  	tf.t.Helper()
   215  	gotSize := 0
   216  	gotEndStream := false
   217  	if want.data != nil {
   218  		want.size = len(want.data)
   219  	}
   220  	var gotData []byte
   221  	for {
   222  		fr := tf.readFrame()
   223  		if fr == nil {
   224  			break
   225  		}
   226  		data, ok := fr.(*DataFrame)
   227  		if !ok {
   228  			tf.t.Fatalf("got frame %T, want DataFrame", fr)
   229  		}
   230  		if want.data != nil {
   231  			gotData = append(gotData, data.Data()...)
   232  		}
   233  		gotSize += len(data.Data())
   234  		if data.StreamEnded() {
   235  			gotEndStream = true
   236  			break
   237  		}
   238  		if !want.endStream && gotSize >= want.size {
   239  			break
   240  		}
   241  		if !want.multiple {
   242  			break
   243  		}
   244  	}
   245  	if gotSize != want.size {
   246  		tf.t.Fatalf("got %v bytes of DATA frames, want %v", gotSize, want.size)
   247  	}
   248  	if gotEndStream != want.endStream {
   249  		tf.t.Fatalf("after %v bytes of DATA frames, got END_STREAM=%v; want %v", gotSize, gotEndStream, want.endStream)
   250  	}
   251  	if want.data != nil && !bytes.Equal(gotData, want.data) {
   252  		tf.t.Fatalf("got data %q, want %q", gotData, want.data)
   253  	}
   254  }
   255  
   256  func (tf *testConnFramer) wantRSTStream(streamID uint32, code ErrCode) {
   257  	tf.t.Helper()
   258  	fr := readFrame[*RSTStreamFrame](tf.t, tf)
   259  	if fr.StreamID != streamID || fr.ErrCode != code {
   260  		tf.t.Fatalf("got %v, want RST_STREAM StreamID=%v, code=%v", summarizeFrame(fr), streamID, code)
   261  	}
   262  }
   263  
   264  func (tf *testConnFramer) wantSettings(want map[SettingID]uint32) {
   265  	fr := readFrame[*SettingsFrame](tf.t, tf)
   266  	if fr.Header().Flags.Has(FlagSettingsAck) {
   267  		tf.t.Errorf("got SETTINGS frame with ACK set, want no ACK")
   268  	}
   269  	for wantID, wantVal := range want {
   270  		gotVal, ok := fr.Value(wantID)
   271  		if !ok {
   272  			tf.t.Errorf("SETTINGS: %v is not set, want %v", wantID, wantVal)
   273  		} else if gotVal != wantVal {
   274  			tf.t.Errorf("SETTINGS: %v is %v, want %v", wantID, gotVal, wantVal)
   275  		}
   276  	}
   277  	if tf.t.Failed() {
   278  		tf.t.Fatalf("%v", fr)
   279  	}
   280  }
   281  
   282  func (tf *testConnFramer) wantSettingsAck() {
   283  	tf.t.Helper()
   284  	fr := readFrame[*SettingsFrame](tf.t, tf)
   285  	if !fr.Header().Flags.Has(FlagSettingsAck) {
   286  		tf.t.Fatal("Settings Frame didn't have ACK set")
   287  	}
   288  }
   289  
   290  func (tf *testConnFramer) wantGoAway(maxStreamID uint32, code ErrCode) {
   291  	tf.t.Helper()
   292  	fr := readFrame[*GoAwayFrame](tf.t, tf)
   293  	if fr.LastStreamID != maxStreamID || fr.ErrCode != code {
   294  		tf.t.Fatalf("got %v, want GOAWAY LastStreamID=%v, code=%v", summarizeFrame(fr), maxStreamID, code)
   295  	}
   296  }
   297  
   298  func (tf *testConnFramer) wantWindowUpdate(streamID, incr uint32) {
   299  	tf.t.Helper()
   300  	wu := readFrame[*WindowUpdateFrame](tf.t, tf)
   301  	if wu.FrameHeader.StreamID != streamID {
   302  		tf.t.Fatalf("WindowUpdate StreamID = %d; want %d", wu.FrameHeader.StreamID, streamID)
   303  	}
   304  	if wu.Increment != incr {
   305  		tf.t.Fatalf("WindowUpdate increment = %d; want %d", wu.Increment, incr)
   306  	}
   307  }
   308  
   309  func (tf *testConnFramer) wantClosed() {
   310  	tf.t.Helper()
   311  	fr, err := tf.fr.ReadFrame()
   312  	if err == nil {
   313  		tf.t.Fatalf("got unexpected frame (want closed connection): %v", fr)
   314  	}
   315  	if err == os.ErrDeadlineExceeded {
   316  		tf.t.Fatalf("connection is not closed; want it to be")
   317  	}
   318  }
   319  
   320  func (tf *testConnFramer) wantIdle() {
   321  	tf.t.Helper()
   322  	fr, err := tf.fr.ReadFrame()
   323  	if err == nil {
   324  		tf.t.Fatalf("got unexpected frame (want idle connection): %v", fr)
   325  	}
   326  	if err != os.ErrDeadlineExceeded {
   327  		tf.t.Fatalf("got unexpected frame error (want idle connection): %v", err)
   328  	}
   329  }
   330  
   331  func (tf *testConnFramer) writeSettings(settings ...Setting) {
   332  	tf.t.Helper()
   333  	if err := tf.fr.WriteSettings(settings...); err != nil {
   334  		tf.t.Fatal(err)
   335  	}
   336  }
   337  
   338  func (tf *testConnFramer) writeSettingsAck() {
   339  	tf.t.Helper()
   340  	if err := tf.fr.WriteSettingsAck(); err != nil {
   341  		tf.t.Fatal(err)
   342  	}
   343  }
   344  
   345  func (tf *testConnFramer) writeData(streamID uint32, endStream bool, data []byte) {
   346  	tf.t.Helper()
   347  	if err := tf.fr.WriteData(streamID, endStream, data); err != nil {
   348  		tf.t.Fatal(err)
   349  	}
   350  }
   351  
   352  func (tf *testConnFramer) writeDataPadded(streamID uint32, endStream bool, data, pad []byte) {
   353  	tf.t.Helper()
   354  	if err := tf.fr.WriteDataPadded(streamID, endStream, data, pad); err != nil {
   355  		tf.t.Fatal(err)
   356  	}
   357  }
   358  
   359  func (tf *testConnFramer) writeHeaders(p HeadersFrameParam) {
   360  	tf.t.Helper()
   361  	if err := tf.fr.WriteHeaders(p); err != nil {
   362  		tf.t.Fatal(err)
   363  	}
   364  }
   365  
   366  // writeHeadersMode writes header frames, as modified by mode:
   367  //
   368  //   - noHeader: Don't write the header.
   369  //   - oneHeader: Write a single HEADERS frame.
   370  //   - splitHeader: Write a HEADERS frame and CONTINUATION frame.
   371  func (tf *testConnFramer) writeHeadersMode(mode headerType, p HeadersFrameParam) {
   372  	tf.t.Helper()
   373  	switch mode {
   374  	case noHeader:
   375  	case oneHeader:
   376  		tf.writeHeaders(p)
   377  	case splitHeader:
   378  		if len(p.BlockFragment) < 2 {
   379  			panic("too small")
   380  		}
   381  		contData := p.BlockFragment[1:]
   382  		contEnd := p.EndHeaders
   383  		p.BlockFragment = p.BlockFragment[:1]
   384  		p.EndHeaders = false
   385  		tf.writeHeaders(p)
   386  		tf.writeContinuation(p.StreamID, contEnd, contData)
   387  	default:
   388  		panic("bogus mode")
   389  	}
   390  }
   391  
   392  func (tf *testConnFramer) writeContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) {
   393  	tf.t.Helper()
   394  	if err := tf.fr.WriteContinuation(streamID, endHeaders, headerBlockFragment); err != nil {
   395  		tf.t.Fatal(err)
   396  	}
   397  }
   398  
   399  func (tf *testConnFramer) writePriority(id uint32, p PriorityParam) {
   400  	if err := tf.fr.WritePriority(id, p); err != nil {
   401  		tf.t.Fatal(err)
   402  	}
   403  }
   404  
   405  func (tf *testConnFramer) writeRSTStream(streamID uint32, code ErrCode) {
   406  	tf.t.Helper()
   407  	if err := tf.fr.WriteRSTStream(streamID, code); err != nil {
   408  		tf.t.Fatal(err)
   409  	}
   410  }
   411  
   412  func (tf *testConnFramer) writePing(ack bool, data [8]byte) {
   413  	tf.t.Helper()
   414  	if err := tf.fr.WritePing(ack, data); err != nil {
   415  		tf.t.Fatal(err)
   416  	}
   417  }
   418  
   419  func (tf *testConnFramer) writeGoAway(maxStreamID uint32, code ErrCode, debugData []byte) {
   420  	tf.t.Helper()
   421  	if err := tf.fr.WriteGoAway(maxStreamID, code, debugData); err != nil {
   422  		tf.t.Fatal(err)
   423  	}
   424  }
   425  
   426  func (tf *testConnFramer) writeWindowUpdate(streamID, incr uint32) {
   427  	tf.t.Helper()
   428  	if err := tf.fr.WriteWindowUpdate(streamID, incr); err != nil {
   429  		tf.t.Fatal(err)
   430  	}
   431  }
   432  

View as plain text