1
2
3
4
5
6
7
8 package http2
9
10 import (
11 "bytes"
12 "context"
13 "crypto/tls"
14 "fmt"
15 "io"
16 "net/http"
17 "reflect"
18 "sync/atomic"
19 "testing"
20 "time"
21
22 "golang.org/x/net/http2/hpack"
23 )
24
25
26 func TestTestClientConn(t *testing.T) {
27
28 tc := newTestClientConn(t)
29
30
31
32
33
34 tc.greet()
35
36
37
38 body := tc.newRequestBody()
39 body.writeBytes(10)
40 body.closeWithError(io.EOF)
41
42
43
44 req, _ := http.NewRequest("PUT", "https://dummy.tld/", body)
45 rt := tc.roundTrip(req)
46
47
48
49 tc.wantHeaders(wantHeader{
50 streamID: rt.streamID(),
51 endStream: false,
52 header: http.Header{
53 ":authority": []string{"dummy.tld"},
54 ":method": []string{"PUT"},
55 ":path": []string{"/"},
56 },
57 })
58
59 tc.wantData(wantData{
60 streamID: rt.streamID(),
61 endStream: true,
62 size: 10,
63 multiple: true,
64 })
65
66
67 tc.writeHeaders(HeadersFrameParam{
68 StreamID: rt.streamID(),
69 EndHeaders: true,
70 EndStream: true,
71 BlockFragment: tc.makeHeaderBlockFragment(
72 ":status", "200",
73 ),
74 })
75
76
77
78
79 rt.wantStatus(200)
80 rt.wantBody(nil)
81 }
82
83
84
85
86
87
88
89
90
91
92 type testClientConn struct {
93 t *testing.T
94
95 tr *Transport
96 fr *Framer
97 cc *ClientConn
98 group *synctestGroup
99 testConnFramer
100
101 encbuf bytes.Buffer
102 enc *hpack.Encoder
103
104 roundtrips []*testRoundTrip
105
106 netconn *synctestNetConn
107 }
108
109 func newTestClientConnFromClientConn(t *testing.T, cc *ClientConn) *testClientConn {
110 tc := &testClientConn{
111 t: t,
112 tr: cc.t,
113 cc: cc,
114 group: cc.t.transportTestHooks.group.(*synctestGroup),
115 }
116
117
118 var srv *synctestNetConn
119 if cc.tconn == nil {
120
121
122
123 cc.tconn, srv = synctestNetPipe(tc.group)
124 } else {
125
126
127 if tc, ok := cc.tconn.(*tls.Conn); ok {
128
129
130 cc.tconn = tc.NetConn()
131 }
132 srv = cc.tconn.(*synctestNetConn).peer
133 }
134
135 srv.SetReadDeadline(tc.group.Now())
136 srv.autoWait = true
137 tc.netconn = srv
138 tc.enc = hpack.NewEncoder(&tc.encbuf)
139 tc.fr = NewFramer(srv, srv)
140 tc.testConnFramer = testConnFramer{
141 t: t,
142 fr: tc.fr,
143 dec: hpack.NewDecoder(initialHeaderTableSize, nil),
144 }
145 tc.fr.SetMaxReadFrameSize(10 << 20)
146 t.Cleanup(func() {
147 tc.closeWrite()
148 })
149
150 return tc
151 }
152
153 func (tc *testClientConn) readClientPreface() {
154 tc.t.Helper()
155
156 buf := make([]byte, len(clientPreface))
157 if _, err := io.ReadFull(tc.netconn, buf); err != nil {
158 tc.t.Fatalf("reading preface: %v", err)
159 }
160 if !bytes.Equal(buf, clientPreface) {
161 tc.t.Fatalf("client preface: %q, want %q", buf, clientPreface)
162 }
163 }
164
165 func newTestClientConn(t *testing.T, opts ...any) *testClientConn {
166 t.Helper()
167
168 tt := newTestTransport(t, opts...)
169 const singleUse = false
170 _, err := tt.tr.newClientConn(nil, singleUse)
171 if err != nil {
172 t.Fatalf("newClientConn: %v", err)
173 }
174
175 return tt.getConn()
176 }
177
178
179
180 func (tc *testClientConn) sync() {
181 tc.group.Wait()
182 }
183
184
185 func (tc *testClientConn) advance(d time.Duration) {
186 tc.group.AdvanceTime(d)
187 tc.sync()
188 }
189
190
191 func (tc *testClientConn) hasFrame() bool {
192 return len(tc.netconn.Peek()) > 0
193 }
194
195
196 func (tc *testClientConn) isClosed() bool {
197 return tc.netconn.IsClosedByPeer()
198 }
199
200
201
202 func (tc *testClientConn) closeWrite() {
203 tc.netconn.Close()
204 }
205
206
207 type testRequestBody struct {
208 tc *testClientConn
209 gate gate
210
211
212 buf bytes.Buffer
213 bytes int
214
215 err error
216 }
217
218 func (tc *testClientConn) newRequestBody() *testRequestBody {
219 b := &testRequestBody{
220 tc: tc,
221 gate: newGate(),
222 }
223 return b
224 }
225
226 func (b *testRequestBody) unlock() {
227 b.gate.unlock(b.buf.Len() > 0 || b.bytes > 0 || b.err != nil)
228 }
229
230
231 func (b *testRequestBody) Read(p []byte) (n int, _ error) {
232 if err := b.gate.waitAndLock(context.Background()); err != nil {
233 return 0, err
234 }
235 defer b.unlock()
236 switch {
237 case b.buf.Len() > 0:
238 return b.buf.Read(p)
239 case b.bytes > 0:
240 if len(p) > b.bytes {
241 p = p[:b.bytes]
242 }
243 b.bytes -= len(p)
244 for i := range p {
245 p[i] = 'A'
246 }
247 return len(p), nil
248 default:
249 return 0, b.err
250 }
251 }
252
253
254 func (b *testRequestBody) Close() error {
255 return nil
256 }
257
258
259 func (b *testRequestBody) writeBytes(n int) {
260 defer b.tc.sync()
261 b.gate.lock()
262 defer b.unlock()
263 b.bytes += n
264 b.checkWrite()
265 b.tc.sync()
266 }
267
268
269 func (b *testRequestBody) Write(p []byte) (int, error) {
270 defer b.tc.sync()
271 b.gate.lock()
272 defer b.unlock()
273 n, err := b.buf.Write(p)
274 b.checkWrite()
275 return n, err
276 }
277
278 func (b *testRequestBody) checkWrite() {
279 if b.bytes > 0 && b.buf.Len() > 0 {
280 b.tc.t.Fatalf("can't interleave Write and writeBytes on request body")
281 }
282 if b.err != nil {
283 b.tc.t.Fatalf("can't write to request body after closeWithError")
284 }
285 }
286
287
288 func (b *testRequestBody) closeWithError(err error) {
289 defer b.tc.sync()
290 b.gate.lock()
291 defer b.unlock()
292 b.err = err
293 }
294
295
296
297
298
299 func (tc *testClientConn) roundTrip(req *http.Request) *testRoundTrip {
300 rt := &testRoundTrip{
301 t: tc.t,
302 donec: make(chan struct{}),
303 }
304 tc.roundtrips = append(tc.roundtrips, rt)
305 go func() {
306 tc.group.Join()
307 defer close(rt.donec)
308 rt.resp, rt.respErr = tc.cc.roundTrip(req, func(cs *clientStream) {
309 rt.id.Store(cs.ID)
310 })
311 }()
312 tc.sync()
313
314 tc.t.Cleanup(func() {
315 if !rt.done() {
316 return
317 }
318 res, _ := rt.result()
319 if res != nil {
320 res.Body.Close()
321 }
322 })
323
324 return rt
325 }
326
327 func (tc *testClientConn) greet(settings ...Setting) {
328 tc.wantFrameType(FrameSettings)
329 tc.wantFrameType(FrameWindowUpdate)
330 tc.writeSettings(settings...)
331 tc.writeSettingsAck()
332 tc.wantFrameType(FrameSettings)
333 }
334
335
336
337
338
339 func (tc *testClientConn) makeHeaderBlockFragment(s ...string) []byte {
340 if len(s)%2 != 0 {
341 tc.t.Fatalf("uneven list of header name/value pairs")
342 }
343 tc.encbuf.Reset()
344 for i := 0; i < len(s); i += 2 {
345 tc.enc.WriteField(hpack.HeaderField{Name: s[i], Value: s[i+1]})
346 }
347 return tc.encbuf.Bytes()
348 }
349
350
351
352 func (tc *testClientConn) inflowWindow(streamID uint32) int32 {
353 tc.cc.mu.Lock()
354 defer tc.cc.mu.Unlock()
355 if streamID == 0 {
356 return tc.cc.inflow.avail + tc.cc.inflow.unsent
357 }
358 cs := tc.cc.streams[streamID]
359 if cs == nil {
360 tc.t.Errorf("no stream with id %v", streamID)
361 return -1
362 }
363 return cs.inflow.avail + cs.inflow.unsent
364 }
365
366
// testRoundTrip tracks one in-flight or completed round trip.
type testRoundTrip struct {
	t       *testing.T
	resp    *http.Response // result; valid only after donec is closed
	respErr error          // result; valid only after donec is closed
	donec   chan struct{}  // closed when the round trip returns
	id      atomic.Uint32  // stream ID, or 0 while still unknown
}
374
375
376 func (rt *testRoundTrip) streamID() uint32 {
377 id := rt.id.Load()
378 if id == 0 {
379 panic("stream ID unknown")
380 }
381 return id
382 }
383
384
385 func (rt *testRoundTrip) done() bool {
386 select {
387 case <-rt.donec:
388 return true
389 default:
390 return false
391 }
392 }
393
394
395 func (rt *testRoundTrip) result() (*http.Response, error) {
396 t := rt.t
397 t.Helper()
398 select {
399 case <-rt.donec:
400 default:
401 t.Fatalf("RoundTrip is not done; want it to be")
402 }
403 return rt.resp, rt.respErr
404 }
405
406
407
408 func (rt *testRoundTrip) response() *http.Response {
409 t := rt.t
410 t.Helper()
411 resp, err := rt.result()
412 if err != nil {
413 t.Fatalf("RoundTrip returned unexpected error: %v", rt.respErr)
414 }
415 if resp == nil {
416 t.Fatalf("RoundTrip returned nil *Response and nil error")
417 }
418 return resp
419 }
420
421
422 func (rt *testRoundTrip) err() error {
423 t := rt.t
424 t.Helper()
425 _, err := rt.result()
426 return err
427 }
428
429
430 func (rt *testRoundTrip) wantStatus(want int) {
431 t := rt.t
432 t.Helper()
433 if got := rt.response().StatusCode; got != want {
434 t.Fatalf("got response status %v, want %v", got, want)
435 }
436 }
437
438
439 func (rt *testRoundTrip) readBody() ([]byte, error) {
440 t := rt.t
441 t.Helper()
442 return io.ReadAll(rt.response().Body)
443 }
444
445
446
447 func (rt *testRoundTrip) wantBody(want []byte) {
448 t := rt.t
449 t.Helper()
450 got, err := rt.readBody()
451 if err != nil {
452 t.Fatalf("unexpected error reading response body: %v", err)
453 }
454 if !bytes.Equal(got, want) {
455 t.Fatalf("unexpected response body:\ngot: %q\nwant: %q", got, want)
456 }
457 }
458
459
460 func (rt *testRoundTrip) wantHeaders(want http.Header) {
461 t := rt.t
462 t.Helper()
463 res := rt.response()
464 if diff := diffHeaders(res.Header, want); diff != "" {
465 t.Fatalf("unexpected response headers:\n%v", diff)
466 }
467 }
468
469
470 func (rt *testRoundTrip) wantTrailers(want http.Header) {
471 t := rt.t
472 t.Helper()
473 res := rt.response()
474 if diff := diffHeaders(res.Trailer, want); diff != "" {
475 t.Fatalf("unexpected response trailers:\n%v", diff)
476 }
477 }
478
// diffHeaders compares two header maps. It returns "" when they are equal,
// and otherwise a two-line "got/want" description of both.
//
// Two empty maps compare equal regardless of nil-ness; reflect.DeepEqual on
// its own would distinguish a nil map from an empty non-nil one.
func diffHeaders(got, want http.Header) string {
	switch {
	case len(got) == 0 && len(want) == 0:
		return ""
	case reflect.DeepEqual(got, want):
		return ""
	}
	return fmt.Sprintf("got: %v\nwant: %v", got, want)
}
491
492
493
494
495 type testTransport struct {
496 t *testing.T
497 tr *Transport
498 group *synctestGroup
499
500 ccs []*testClientConn
501 }
502
503 func newTestTransport(t *testing.T, opts ...any) *testTransport {
504 tt := &testTransport{
505 t: t,
506 group: newSynctest(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)),
507 }
508 tt.group.Join()
509
510 tr := &Transport{}
511 for _, o := range opts {
512 switch o := o.(type) {
513 case func(*http.Transport):
514 if tr.t1 == nil {
515 tr.t1 = &http.Transport{}
516 }
517 o(tr.t1)
518 case func(*Transport):
519 o(tr)
520 case *Transport:
521 tr = o
522 }
523 }
524 tt.tr = tr
525
526 tr.transportTestHooks = &transportTestHooks{
527 group: tt.group,
528 newclientconn: func(cc *ClientConn) {
529 tc := newTestClientConnFromClientConn(t, cc)
530 tt.ccs = append(tt.ccs, tc)
531 },
532 }
533
534 t.Cleanup(func() {
535 tt.sync()
536 if len(tt.ccs) > 0 {
537 t.Fatalf("%v test ClientConns created, but not examined by test", len(tt.ccs))
538 }
539 tt.group.Close(t)
540 })
541
542 return tt
543 }
544
545 func (tt *testTransport) sync() {
546 tt.group.Wait()
547 }
548
549 func (tt *testTransport) advance(d time.Duration) {
550 tt.group.AdvanceTime(d)
551 tt.sync()
552 }
553
554 func (tt *testTransport) hasConn() bool {
555 return len(tt.ccs) > 0
556 }
557
558 func (tt *testTransport) getConn() *testClientConn {
559 tt.t.Helper()
560 if len(tt.ccs) == 0 {
561 tt.t.Fatalf("no new ClientConns created; wanted one")
562 }
563 tc := tt.ccs[0]
564 tt.ccs = tt.ccs[1:]
565 tc.sync()
566 tc.readClientPreface()
567 tc.sync()
568 return tc
569 }
570
571 func (tt *testTransport) roundTrip(req *http.Request) *testRoundTrip {
572 rt := &testRoundTrip{
573 t: tt.t,
574 donec: make(chan struct{}),
575 }
576 go func() {
577 tt.group.Join()
578 defer close(rt.donec)
579 rt.resp, rt.respErr = tt.tr.RoundTrip(req)
580 }()
581 tt.sync()
582
583 tt.t.Cleanup(func() {
584 if !rt.done() {
585 return
586 }
587 res, _ := rt.result()
588 if res != nil {
589 res.Body.Close()
590 }
591 })
592
593 return rt
594 }
595
View as plain text