1
2
3
4
5 package http2
6
7 import (
8 "bytes"
9 "io"
10 "net/http"
11 "os"
12 "reflect"
13 "slices"
14 "testing"
15
16 "golang.org/x/net/http2/hpack"
17 )
18
19 type testConnFramer struct {
20 t testing.TB
21 fr *Framer
22 dec *hpack.Decoder
23 }
24
25
26
27 func (tf *testConnFramer) readFrame() Frame {
28 tf.t.Helper()
29 fr, err := tf.fr.ReadFrame()
30 if err == io.EOF || err == os.ErrDeadlineExceeded {
31 return nil
32 }
33 if err != nil {
34 tf.t.Fatalf("ReadFrame: %v", err)
35 }
36 return fr
37 }
38
39 type readFramer interface {
40 readFrame() Frame
41 }
42
43
44 func readFrame[T any](t testing.TB, framer readFramer) T {
45 t.Helper()
46 var v T
47 fr := framer.readFrame()
48 if fr == nil {
49 t.Fatalf("got no frame, want frame %T", v)
50 }
51 v, ok := fr.(T)
52 if !ok {
53 t.Fatalf("got frame %T, want %T", fr, v)
54 }
55 return v
56 }
57
58
59
60 func (tf *testConnFramer) wantFrameType(want FrameType) {
61 tf.t.Helper()
62 fr := tf.readFrame()
63 if fr == nil {
64 tf.t.Fatalf("got no frame, want frame %v", want)
65 }
66 if got := fr.Header().Type; got != want {
67 tf.t.Fatalf("got frame %v, want %v", got, want)
68 }
69 }
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93 func (tf *testConnFramer) wantUnorderedFrames(want ...any) {
94 tf.t.Helper()
95 want = slices.Clone(want)
96 seen := 0
97 frame:
98 for seen < len(want) && !tf.t.Failed() {
99 fr := tf.readFrame()
100 if fr == nil {
101 break
102 }
103 for i, f := range want {
104 if f == nil {
105 continue
106 }
107 typ := reflect.TypeOf(f)
108 if typ.Kind() != reflect.Func ||
109 typ.NumIn() != 1 ||
110 typ.NumOut() != 1 ||
111 typ.Out(0) != reflect.TypeOf(true) {
112 tf.t.Fatalf("expected func(*SomeFrame) bool, got %T", f)
113 }
114 if typ.In(0) == reflect.TypeOf(fr) {
115 out := reflect.ValueOf(f).Call([]reflect.Value{reflect.ValueOf(fr)})
116 if out[0].Bool() {
117 want[i] = nil
118 seen++
119 }
120 continue frame
121 }
122 }
123 tf.t.Errorf("got unexpected frame type %T", fr)
124 }
125 if seen < len(want) {
126 for _, f := range want {
127 if f == nil {
128 continue
129 }
130 tf.t.Errorf("did not see expected frame: %v", reflect.TypeOf(f).In(0))
131 }
132 tf.t.Fatalf("did not see %v expected frame types", len(want)-seen)
133 }
134 }
135
// wantHeader describes the HEADERS frame expected by wantHeaders.
type wantHeader struct {
	// streamID is the stream the HEADERS frame must be sent on.
	streamID uint32
	// endStream is the expected state of the frame's END_STREAM flag.
	endStream bool
	// header lists fields that must decode to exactly these values.
	// Decoded fields whose names are absent from this map are not checked.
	header http.Header
}
141
142
143
144 func (tf *testConnFramer) wantHeaders(want wantHeader) {
145 tf.t.Helper()
146
147 hf := readFrame[*HeadersFrame](tf.t, tf)
148 if got, want := hf.StreamID, want.streamID; got != want {
149 tf.t.Fatalf("got stream ID %v, want %v", got, want)
150 }
151 if got, want := hf.StreamEnded(), want.endStream; got != want {
152 tf.t.Fatalf("got stream ended %v, want %v", got, want)
153 }
154
155 gotHeader := make(http.Header)
156 tf.dec.SetEmitFunc(func(hf hpack.HeaderField) {
157 gotHeader[hf.Name] = append(gotHeader[hf.Name], hf.Value)
158 })
159 defer tf.dec.SetEmitFunc(nil)
160 if _, err := tf.dec.Write(hf.HeaderBlockFragment()); err != nil {
161 tf.t.Fatalf("decoding HEADERS frame: %v", err)
162 }
163 headersEnded := hf.HeadersEnded()
164 for !headersEnded {
165 cf := readFrame[*ContinuationFrame](tf.t, tf)
166 if cf == nil {
167 tf.t.Fatalf("got end of frames, want CONTINUATION")
168 }
169 if _, err := tf.dec.Write(cf.HeaderBlockFragment()); err != nil {
170 tf.t.Fatalf("decoding CONTINUATION frame: %v", err)
171 }
172 headersEnded = cf.HeadersEnded()
173 }
174 if err := tf.dec.Close(); err != nil {
175 tf.t.Fatalf("hpack decoding error: %v", err)
176 }
177
178 for k, v := range want.header {
179 if !reflect.DeepEqual(v, gotHeader[k]) {
180 tf.t.Fatalf("got header %q = %q; want %q", k, v, gotHeader[k])
181 }
182 }
183 }
184
185
186
187 func (tf *testConnFramer) decodeHeader(headerBlock []byte) (pairs [][2]string) {
188 tf.dec.SetEmitFunc(func(hf hpack.HeaderField) {
189 if hf.Name == "date" {
190 return
191 }
192 pairs = append(pairs, [2]string{hf.Name, hf.Value})
193 })
194 defer tf.dec.SetEmitFunc(nil)
195 if _, err := tf.dec.Write(headerBlock); err != nil {
196 tf.t.Fatalf("hpack decoding error: %v", err)
197 }
198 if err := tf.dec.Close(); err != nil {
199 tf.t.Fatalf("hpack decoding error: %v", err)
200 }
201 return pairs
202 }
203
// wantData describes the DATA frame or frames expected by the wantData
// method.
type wantData struct {
	// streamID is the stream the data is expected on.
	streamID uint32
	// endStream is whether the data read must finish with END_STREAM set.
	endStream bool
	// size is the expected total payload length in bytes. It is replaced by
	// len(data) when data is non-nil.
	size int
	// data, when non-nil, is the exact payload expected.
	data []byte
	// multiple permits the payload to be spread over more than one frame.
	multiple bool
}
211
212
213 func (tf *testConnFramer) wantData(want wantData) {
214 tf.t.Helper()
215 gotSize := 0
216 gotEndStream := false
217 if want.data != nil {
218 want.size = len(want.data)
219 }
220 var gotData []byte
221 for {
222 fr := tf.readFrame()
223 if fr == nil {
224 break
225 }
226 data, ok := fr.(*DataFrame)
227 if !ok {
228 tf.t.Fatalf("got frame %T, want DataFrame", fr)
229 }
230 if want.data != nil {
231 gotData = append(gotData, data.Data()...)
232 }
233 gotSize += len(data.Data())
234 if data.StreamEnded() {
235 gotEndStream = true
236 break
237 }
238 if !want.endStream && gotSize >= want.size {
239 break
240 }
241 if !want.multiple {
242 break
243 }
244 }
245 if gotSize != want.size {
246 tf.t.Fatalf("got %v bytes of DATA frames, want %v", gotSize, want.size)
247 }
248 if gotEndStream != want.endStream {
249 tf.t.Fatalf("after %v bytes of DATA frames, got END_STREAM=%v; want %v", gotSize, gotEndStream, want.endStream)
250 }
251 if want.data != nil && !bytes.Equal(gotData, want.data) {
252 tf.t.Fatalf("got data %q, want %q", gotData, want.data)
253 }
254 }
255
256 func (tf *testConnFramer) wantRSTStream(streamID uint32, code ErrCode) {
257 tf.t.Helper()
258 fr := readFrame[*RSTStreamFrame](tf.t, tf)
259 if fr.StreamID != streamID || fr.ErrCode != code {
260 tf.t.Fatalf("got %v, want RST_STREAM StreamID=%v, code=%v", summarizeFrame(fr), streamID, code)
261 }
262 }
263
264 func (tf *testConnFramer) wantSettings(want map[SettingID]uint32) {
265 fr := readFrame[*SettingsFrame](tf.t, tf)
266 if fr.Header().Flags.Has(FlagSettingsAck) {
267 tf.t.Errorf("got SETTINGS frame with ACK set, want no ACK")
268 }
269 for wantID, wantVal := range want {
270 gotVal, ok := fr.Value(wantID)
271 if !ok {
272 tf.t.Errorf("SETTINGS: %v is not set, want %v", wantID, wantVal)
273 } else if gotVal != wantVal {
274 tf.t.Errorf("SETTINGS: %v is %v, want %v", wantID, gotVal, wantVal)
275 }
276 }
277 if tf.t.Failed() {
278 tf.t.Fatalf("%v", fr)
279 }
280 }
281
282 func (tf *testConnFramer) wantSettingsAck() {
283 tf.t.Helper()
284 fr := readFrame[*SettingsFrame](tf.t, tf)
285 if !fr.Header().Flags.Has(FlagSettingsAck) {
286 tf.t.Fatal("Settings Frame didn't have ACK set")
287 }
288 }
289
290 func (tf *testConnFramer) wantGoAway(maxStreamID uint32, code ErrCode) {
291 tf.t.Helper()
292 fr := readFrame[*GoAwayFrame](tf.t, tf)
293 if fr.LastStreamID != maxStreamID || fr.ErrCode != code {
294 tf.t.Fatalf("got %v, want GOAWAY LastStreamID=%v, code=%v", summarizeFrame(fr), maxStreamID, code)
295 }
296 }
297
298 func (tf *testConnFramer) wantWindowUpdate(streamID, incr uint32) {
299 tf.t.Helper()
300 wu := readFrame[*WindowUpdateFrame](tf.t, tf)
301 if wu.FrameHeader.StreamID != streamID {
302 tf.t.Fatalf("WindowUpdate StreamID = %d; want %d", wu.FrameHeader.StreamID, streamID)
303 }
304 if wu.Increment != incr {
305 tf.t.Fatalf("WindowUpdate increment = %d; want %d", wu.Increment, incr)
306 }
307 }
308
309 func (tf *testConnFramer) wantClosed() {
310 tf.t.Helper()
311 fr, err := tf.fr.ReadFrame()
312 if err == nil {
313 tf.t.Fatalf("got unexpected frame (want closed connection): %v", fr)
314 }
315 if err == os.ErrDeadlineExceeded {
316 tf.t.Fatalf("connection is not closed; want it to be")
317 }
318 }
319
320 func (tf *testConnFramer) wantIdle() {
321 tf.t.Helper()
322 fr, err := tf.fr.ReadFrame()
323 if err == nil {
324 tf.t.Fatalf("got unexpected frame (want idle connection): %v", fr)
325 }
326 if err != os.ErrDeadlineExceeded {
327 tf.t.Fatalf("got unexpected frame error (want idle connection): %v", err)
328 }
329 }
330
331 func (tf *testConnFramer) writeSettings(settings ...Setting) {
332 tf.t.Helper()
333 if err := tf.fr.WriteSettings(settings...); err != nil {
334 tf.t.Fatal(err)
335 }
336 }
337
338 func (tf *testConnFramer) writeSettingsAck() {
339 tf.t.Helper()
340 if err := tf.fr.WriteSettingsAck(); err != nil {
341 tf.t.Fatal(err)
342 }
343 }
344
345 func (tf *testConnFramer) writeData(streamID uint32, endStream bool, data []byte) {
346 tf.t.Helper()
347 if err := tf.fr.WriteData(streamID, endStream, data); err != nil {
348 tf.t.Fatal(err)
349 }
350 }
351
352 func (tf *testConnFramer) writeDataPadded(streamID uint32, endStream bool, data, pad []byte) {
353 tf.t.Helper()
354 if err := tf.fr.WriteDataPadded(streamID, endStream, data, pad); err != nil {
355 tf.t.Fatal(err)
356 }
357 }
358
359 func (tf *testConnFramer) writeHeaders(p HeadersFrameParam) {
360 tf.t.Helper()
361 if err := tf.fr.WriteHeaders(p); err != nil {
362 tf.t.Fatal(err)
363 }
364 }
365
366
367
368
369
370
371 func (tf *testConnFramer) writeHeadersMode(mode headerType, p HeadersFrameParam) {
372 tf.t.Helper()
373 switch mode {
374 case noHeader:
375 case oneHeader:
376 tf.writeHeaders(p)
377 case splitHeader:
378 if len(p.BlockFragment) < 2 {
379 panic("too small")
380 }
381 contData := p.BlockFragment[1:]
382 contEnd := p.EndHeaders
383 p.BlockFragment = p.BlockFragment[:1]
384 p.EndHeaders = false
385 tf.writeHeaders(p)
386 tf.writeContinuation(p.StreamID, contEnd, contData)
387 default:
388 panic("bogus mode")
389 }
390 }
391
392 func (tf *testConnFramer) writeContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) {
393 tf.t.Helper()
394 if err := tf.fr.WriteContinuation(streamID, endHeaders, headerBlockFragment); err != nil {
395 tf.t.Fatal(err)
396 }
397 }
398
399 func (tf *testConnFramer) writePriority(id uint32, p PriorityParam) {
400 if err := tf.fr.WritePriority(id, p); err != nil {
401 tf.t.Fatal(err)
402 }
403 }
404
405 func (tf *testConnFramer) writeRSTStream(streamID uint32, code ErrCode) {
406 tf.t.Helper()
407 if err := tf.fr.WriteRSTStream(streamID, code); err != nil {
408 tf.t.Fatal(err)
409 }
410 }
411
412 func (tf *testConnFramer) writePing(ack bool, data [8]byte) {
413 tf.t.Helper()
414 if err := tf.fr.WritePing(ack, data); err != nil {
415 tf.t.Fatal(err)
416 }
417 }
418
419 func (tf *testConnFramer) writeGoAway(maxStreamID uint32, code ErrCode, debugData []byte) {
420 tf.t.Helper()
421 if err := tf.fr.WriteGoAway(maxStreamID, code, debugData); err != nil {
422 tf.t.Fatal(err)
423 }
424 }
425
426 func (tf *testConnFramer) writeWindowUpdate(streamID, incr uint32) {
427 tf.t.Helper()
428 if err := tf.fr.WriteWindowUpdate(streamID, incr); err != nil {
429 tf.t.Fatal(err)
430 }
431 }
432
View as plain text