1
2
3
4
5 package http2
6
7 import (
8 "bufio"
9 "bytes"
10 "compress/gzip"
11 "context"
12 "crypto/tls"
13 "encoding/hex"
14 "errors"
15 "flag"
16 "fmt"
17 "io"
18 "io/fs"
19 "log"
20 "math/rand"
21 "net"
22 "net/http"
23 "net/http/httptest"
24 "net/http/httptrace"
25 "net/textproto"
26 "net/url"
27 "os"
28 "reflect"
29 "runtime"
30 "sort"
31 "strconv"
32 "strings"
33 "sync"
34 "sync/atomic"
35 "testing"
36 "time"
37
38 "golang.org/x/net/http2/hpack"
39 )
40
// Command-line flags for the optional external-network tests.
var (
	extNet        = flag.Bool("extnet", false, "do external network tests")
	transportHost = flag.String("transporthost", "http2.golang.org", "hostname to use for TestTransport")
	insecure      = flag.Bool("insecure", false, "insecure TLS dials")
)

var (
	// tlsConfigInsecure disables certificate verification so test clients can
	// talk to servers presenting self-signed certificates.
	tlsConfigInsecure = &tls.Config{InsecureSkipVerify: true}

	// canceledCtx is a context whose cancel function has already been called
	// by the time package initialization completes.
	canceledCtx = func() context.Context {
		ctx, cancel := context.WithCancel(context.Background())
		cancel()
		return ctx
	}()
)
56
57 func TestTransportExternal(t *testing.T) {
58 if !*extNet {
59 t.Skip("skipping external network test")
60 }
61 req, _ := http.NewRequest("GET", "https://"+*transportHost+"/", nil)
62 rt := &Transport{TLSClientConfig: tlsConfigInsecure}
63 res, err := rt.RoundTrip(req)
64 if err != nil {
65 t.Fatalf("%v", err)
66 }
67 res.Write(os.Stdout)
68 }
69
70 type fakeTLSConn struct {
71 net.Conn
72 }
73
74 func (c *fakeTLSConn) ConnectionState() tls.ConnectionState {
75 return tls.ConnectionState{
76 Version: tls.VersionTLS12,
77 CipherSuite: cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
78 }
79 }
80
81 func startH2cServer(t *testing.T) net.Listener {
82 h2Server := &Server{}
83 l := newLocalListener(t)
84 go func() {
85 conn, err := l.Accept()
86 if err != nil {
87 t.Error(err)
88 return
89 }
90 h2Server.ServeConn(&fakeTLSConn{conn}, &ServeConnOpts{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
91 fmt.Fprintf(w, "Hello, %v, http: %v", r.URL.Path, r.TLS == nil)
92 })})
93 }()
94 return l
95 }
96
// TestIdleConnTimeout makes three sequential requests through a test
// Transport configured with test.idleConnTimeout. After each response it
// advances the test transport's clock by test.wait (tt.advance). It then
// checks whether the connection was closed (test.wantNewConn), and therefore
// whether the next request has to dial a new one.
func TestIdleConnTimeout(t *testing.T) {
	for _, test := range []struct {
		name            string
		idleConnTimeout time.Duration
		wait            time.Duration
		baseTransport   *http.Transport
		wantNewConn     bool
	}{{
		name:            "NoExpiry",
		idleConnTimeout: 2 * time.Second,
		wait:            1 * time.Second,
		baseTransport:   nil,
		wantNewConn:     false,
	}, {
		name:            "H2TransportTimeoutExpires",
		idleConnTimeout: 1 * time.Second,
		wait:            2 * time.Second,
		baseTransport:   nil,
		wantNewConn:     true,
	}, {
		// NOTE(review): baseTransport is never read by the loop body below,
		// so this case only exercises IdleConnTimeout == 0 on the http2
		// Transport. The HTTP/1 transport's timeout is not wired in.
		// TODO confirm whether newTestTransport should receive it.
		name:            "H1TransportTimeoutExpires",
		idleConnTimeout: 0 * time.Second,
		wait:            1 * time.Second,
		baseTransport: &http.Transport{
			IdleConnTimeout: 2 * time.Second,
		},
		wantNewConn: false,
	}} {
		t.Run(test.name, func(t *testing.T) {
			tt := newTestTransport(t, func(tr *Transport) {
				tr.IdleConnTimeout = test.idleConnTimeout
			})
			var tc *testClientConn
			for i := 0; i < 3; i++ {
				req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
				rt := tt.roundTrip(req)

				// A new conn is expected for the first request. It is also
				// expected for every later request when the previous conn
				// should have been closed during the wait.
				wantConn := i == 0 || test.wantNewConn
				if has := tt.hasConn(); has != wantConn {
					t.Fatalf("request %v: hasConn=%v, want %v", i, has, wantConn)
				}
				if wantConn {
					tc = tt.getConn()
					// Consume the client's initial SETTINGS and WINDOW_UPDATE
					// frames, then send the server's SETTINGS.
					tc.wantFrameType(FrameSettings)
					tc.wantFrameType(FrameWindowUpdate)
					tc.writeSettings()
				}
				// getConn should have taken the only pending conn.
				if tt.hasConn() {
					t.Fatalf("request %v: Transport has more than one conn", i)
				}

				// Answer the request with a bodyless 200 that ends the stream.
				hf := readFrame[*HeadersFrame](t, tc)
				tc.writeHeaders(HeadersFrameParam{
					StreamID:   hf.StreamID,
					EndHeaders: true,
					EndStream:  true,
					BlockFragment: tc.makeHeaderBlockFragment(
						":status", "200",
					),
				})
				rt.wantStatus(200)

				// On a freshly dialed conn the client sends one more SETTINGS
				// frame after ours (presumably the ACK). Consume it.
				if wantConn {
					tc.wantFrameType(FrameSettings)
				}

				tt.advance(test.wait)
				if got, want := tc.isClosed(), test.wantNewConn; got != want {
					t.Fatalf("after waiting %v, conn closed=%v; want %v", test.wait, got, want)
				}
			}
		})
	}
}
178
179 func TestTransportH2c(t *testing.T) {
180 l := startH2cServer(t)
181 defer l.Close()
182 req, err := http.NewRequest("GET", "http://"+l.Addr().String()+"/foobar", nil)
183 if err != nil {
184 t.Fatal(err)
185 }
186 var gotConnCnt int32
187 trace := &httptrace.ClientTrace{
188 GotConn: func(connInfo httptrace.GotConnInfo) {
189 if !connInfo.Reused {
190 atomic.AddInt32(&gotConnCnt, 1)
191 }
192 },
193 }
194 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
195 tr := &Transport{
196 AllowHTTP: true,
197 DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
198 return net.Dial(network, addr)
199 },
200 }
201 res, err := tr.RoundTrip(req)
202 if err != nil {
203 t.Fatal(err)
204 }
205 if res.ProtoMajor != 2 {
206 t.Fatal("proto not h2c")
207 }
208 body, err := io.ReadAll(res.Body)
209 if err != nil {
210 t.Fatal(err)
211 }
212 if got, want := string(body), "Hello, /foobar, http: true"; got != want {
213 t.Fatalf("response got %v, want %v", got, want)
214 }
215 if got, want := gotConnCnt, int32(1); got != want {
216 t.Errorf("Too many got connections: %d", gotConnCnt)
217 }
218 }
219
220 func TestTransport(t *testing.T) {
221 const body = "sup"
222 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
223 io.WriteString(w, body)
224 })
225
226 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
227 defer tr.CloseIdleConnections()
228
229 u, err := url.Parse(ts.URL)
230 if err != nil {
231 t.Fatal(err)
232 }
233 for i, m := range []string{"GET", ""} {
234 req := &http.Request{
235 Method: m,
236 URL: u,
237 }
238 res, err := tr.RoundTrip(req)
239 if err != nil {
240 t.Fatalf("%d: %s", i, err)
241 }
242
243 t.Logf("%d: Got res: %+v", i, res)
244 if g, w := res.StatusCode, 200; g != w {
245 t.Errorf("%d: StatusCode = %v; want %v", i, g, w)
246 }
247 if g, w := res.Status, "200 OK"; g != w {
248 t.Errorf("%d: Status = %q; want %q", i, g, w)
249 }
250 wantHeader := http.Header{
251 "Content-Length": []string{"3"},
252 "Content-Type": []string{"text/plain; charset=utf-8"},
253 "Date": []string{"XXX"},
254 }
255 cleanDate(res)
256 if !reflect.DeepEqual(res.Header, wantHeader) {
257 t.Errorf("%d: res Header = %v; want %v", i, res.Header, wantHeader)
258 }
259 if res.Request != req {
260 t.Errorf("%d: Response.Request = %p; want %p", i, res.Request, req)
261 }
262 if res.TLS == nil {
263 t.Errorf("%d: Response.TLS = nil; want non-nil", i)
264 }
265 slurp, err := io.ReadAll(res.Body)
266 if err != nil {
267 t.Errorf("%d: Body read: %v", i, err)
268 } else if string(slurp) != body {
269 t.Errorf("%d: Body = %q; want %q", i, slurp, body)
270 }
271 res.Body.Close()
272 }
273 }
274
275 func testTransportReusesConns(t *testing.T, useClient, wantSame bool, modReq func(*http.Request)) {
276 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
277 io.WriteString(w, r.RemoteAddr)
278 }, func(ts *httptest.Server) {
279 ts.Config.ConnState = func(c net.Conn, st http.ConnState) {
280 t.Logf("conn %v is now state %v", c.RemoteAddr(), st)
281 }
282 })
283 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
284 if useClient {
285 tr.ConnPool = noDialClientConnPool{new(clientConnPool)}
286 }
287 defer tr.CloseIdleConnections()
288 get := func() string {
289 req, err := http.NewRequest("GET", ts.URL, nil)
290 if err != nil {
291 t.Fatal(err)
292 }
293 modReq(req)
294 var res *http.Response
295 if useClient {
296 c := ts.Client()
297 ConfigureTransports(c.Transport.(*http.Transport))
298 res, err = c.Do(req)
299 } else {
300 res, err = tr.RoundTrip(req)
301 }
302 if err != nil {
303 t.Fatal(err)
304 }
305 defer res.Body.Close()
306 slurp, err := io.ReadAll(res.Body)
307 if err != nil {
308 t.Fatalf("Body read: %v", err)
309 }
310 addr := strings.TrimSpace(string(slurp))
311 if addr == "" {
312 t.Fatalf("didn't get an addr in response")
313 }
314 return addr
315 }
316 first := get()
317 second := get()
318 if got := first == second; got != wantSame {
319 t.Errorf("first and second responses on same connection: %v; want %v", got, wantSame)
320 }
321 }
322
323 func TestTransportReusesConns(t *testing.T) {
324 for _, test := range []struct {
325 name string
326 modReq func(*http.Request)
327 wantSame bool
328 }{{
329 name: "ReuseConn",
330 modReq: func(*http.Request) {},
331 wantSame: true,
332 }, {
333 name: "RequestClose",
334 modReq: func(r *http.Request) { r.Close = true },
335 wantSame: false,
336 }, {
337 name: "ConnClose",
338 modReq: func(r *http.Request) { r.Header.Set("Connection", "close") },
339 wantSame: false,
340 }} {
341 t.Run(test.name, func(t *testing.T) {
342 t.Run("Transport", func(t *testing.T) {
343 const useClient = false
344 testTransportReusesConns(t, useClient, test.wantSame, test.modReq)
345 })
346 t.Run("Client", func(t *testing.T) {
347 const useClient = true
348 testTransportReusesConns(t, useClient, test.wantSame, test.modReq)
349 })
350 })
351 }
352 }
353
354 func TestTransportGetGotConnHooks_HTTP2Transport(t *testing.T) {
355 testTransportGetGotConnHooks(t, false)
356 }
357 func TestTransportGetGotConnHooks_Client(t *testing.T) { testTransportGetGotConnHooks(t, true) }
358
359 func testTransportGetGotConnHooks(t *testing.T, useClient bool) {
360 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
361 io.WriteString(w, r.RemoteAddr)
362 })
363
364 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
365 client := ts.Client()
366 ConfigureTransports(client.Transport.(*http.Transport))
367
368 var (
369 getConns int32
370 gotConns int32
371 )
372 for i := 0; i < 2; i++ {
373 trace := &httptrace.ClientTrace{
374 GetConn: func(hostport string) {
375 atomic.AddInt32(&getConns, 1)
376 },
377 GotConn: func(connInfo httptrace.GotConnInfo) {
378 got := atomic.AddInt32(&gotConns, 1)
379 wantReused, wantWasIdle := false, false
380 if got > 1 {
381 wantReused, wantWasIdle = true, true
382 }
383 if connInfo.Reused != wantReused || connInfo.WasIdle != wantWasIdle {
384 t.Errorf("GotConn %v: Reused=%v (want %v), WasIdle=%v (want %v)", i, connInfo.Reused, wantReused, connInfo.WasIdle, wantWasIdle)
385 }
386 },
387 }
388 req, err := http.NewRequest("GET", ts.URL, nil)
389 if err != nil {
390 t.Fatal(err)
391 }
392 req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
393
394 var res *http.Response
395 if useClient {
396 res, err = client.Do(req)
397 } else {
398 res, err = tr.RoundTrip(req)
399 }
400 if err != nil {
401 t.Fatal(err)
402 }
403 res.Body.Close()
404 if get := atomic.LoadInt32(&getConns); get != int32(i+1) {
405 t.Errorf("after request %v, %v calls to GetConns: want %v", i, get, i+1)
406 }
407 if got := atomic.LoadInt32(&gotConns); got != int32(i+1) {
408 t.Errorf("after request %v, %v calls to GotConns: want %v", i, got, i+1)
409 }
410 }
411 }
412
// testNetConn wraps a net.Conn and invokes onClose the first time Close is
// called, so tests can count how many connections were closed.
type testNetConn struct {
	net.Conn
	closed  bool
	onClose func()
}

// Close runs onClose on the first call only (it may be nil), marks the conn
// closed, and always forwards to the embedded Conn's Close.
func (c *testNetConn) Close() error {
	if !c.closed && c.onClose != nil {
		c.onClose()
	}
	c.closed = true
	return c.Conn.Close()
}
427
428
429
430 func TestTransportGroupsPendingDials(t *testing.T) {
431 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
432 })
433 var (
434 mu sync.Mutex
435 dialCount int
436 closeCount int
437 )
438 tr := &Transport{
439 TLSClientConfig: tlsConfigInsecure,
440 DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
441 mu.Lock()
442 dialCount++
443 mu.Unlock()
444 c, err := tls.Dial(network, addr, cfg)
445 return &testNetConn{
446 Conn: c,
447 onClose: func() {
448 mu.Lock()
449 closeCount++
450 mu.Unlock()
451 },
452 }, err
453 },
454 }
455 defer tr.CloseIdleConnections()
456 var wg sync.WaitGroup
457 for i := 0; i < 10; i++ {
458 wg.Add(1)
459 go func() {
460 defer wg.Done()
461 req, err := http.NewRequest("GET", ts.URL, nil)
462 if err != nil {
463 t.Error(err)
464 return
465 }
466 res, err := tr.RoundTrip(req)
467 if err != nil {
468 t.Error(err)
469 return
470 }
471 res.Body.Close()
472 }()
473 }
474 wg.Wait()
475 tr.CloseIdleConnections()
476 if dialCount != 1 {
477 t.Errorf("saw %d dials; want 1", dialCount)
478 }
479 if closeCount != 1 {
480 t.Errorf("saw %d closes; want 1", closeCount)
481 }
482 }
483
484 func TestTransportAbortClosesPipes(t *testing.T) {
485 shutdown := make(chan struct{})
486 ts := newTestServer(t,
487 func(w http.ResponseWriter, r *http.Request) {
488 w.(http.Flusher).Flush()
489 <-shutdown
490 },
491 )
492 defer close(shutdown)
493
494 errCh := make(chan error)
495 go func() {
496 defer close(errCh)
497 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
498 req, err := http.NewRequest("GET", ts.URL, nil)
499 if err != nil {
500 errCh <- err
501 return
502 }
503 res, err := tr.RoundTrip(req)
504 if err != nil {
505 errCh <- err
506 return
507 }
508 defer res.Body.Close()
509 ts.CloseClientConnections()
510 _, err = io.ReadAll(res.Body)
511 if err == nil {
512 errCh <- errors.New("expected error from res.Body.Read")
513 return
514 }
515 }()
516
517 select {
518 case err := <-errCh:
519 if err != nil {
520 t.Fatal(err)
521 }
522
523 case <-time.After(3 * time.Second):
524 t.Fatal("timeout")
525 }
526 }
527
528
529
530 func TestTransportPath(t *testing.T) {
531 gotc := make(chan *url.URL, 1)
532 ts := newTestServer(t,
533 func(w http.ResponseWriter, r *http.Request) {
534 gotc <- r.URL
535 },
536 )
537
538 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
539 defer tr.CloseIdleConnections()
540 const (
541 path = "/testpath"
542 query = "q=1"
543 )
544 surl := ts.URL + path + "?" + query
545 req, err := http.NewRequest("POST", surl, nil)
546 if err != nil {
547 t.Fatal(err)
548 }
549 c := &http.Client{Transport: tr}
550 res, err := c.Do(req)
551 if err != nil {
552 t.Fatal(err)
553 }
554 defer res.Body.Close()
555 got := <-gotc
556 if got.Path != path {
557 t.Errorf("Read Path = %q; want %q", got.Path, path)
558 }
559 if got.RawQuery != query {
560 t.Errorf("Read RawQuery = %q; want %q", got.RawQuery, query)
561 }
562 }
563
// randString returns n pseudo-random bytes (each in 0-255, so the result is
// not necessarily valid UTF-8) as a string. The generator is seeded with n,
// so a given length always yields the same string.
func randString(n int) string {
	src := rand.New(rand.NewSource(int64(n)))
	var sb strings.Builder
	sb.Grow(n)
	for i := 0; i < n; i++ {
		sb.WriteByte(byte(src.Intn(256)))
	}
	return sb.String()
}
572
// panicReader implements Read and Close, and both panic when called. Tests
// use it as a request body that must never be touched.
type panicReader struct{}

// Read always panics with "unexpected Read".
func (panicReader) Read([]byte) (int, error) {
	panic("unexpected Read")
}

// Close always panics with "unexpected Close".
func (panicReader) Close() error {
	panic("unexpected Close")
}
577
578 func TestActualContentLength(t *testing.T) {
579 tests := []struct {
580 req *http.Request
581 want int64
582 }{
583
584 0: {
585 req: &http.Request{Body: panicReader{}},
586 want: -1,
587 },
588
589 1: {
590 req: &http.Request{Body: nil, ContentLength: 5},
591 want: 0,
592 },
593
594 2: {
595 req: &http.Request{Body: panicReader{}, ContentLength: 5},
596 want: 5,
597 },
598
599 3: {
600 req: &http.Request{Body: http.NoBody},
601 want: 0,
602 },
603 }
604 for i, tt := range tests {
605 got := actualContentLength(tt.req)
606 if got != tt.want {
607 t.Errorf("test[%d]: got %d; want %d", i, got, tt.want)
608 }
609 }
610 }
611
612 func TestTransportBody(t *testing.T) {
613 bodyTests := []struct {
614 body string
615 noContentLen bool
616 }{
617 {body: "some message"},
618 {body: "some message", noContentLen: true},
619 {body: strings.Repeat("a", 1<<20), noContentLen: true},
620 {body: strings.Repeat("a", 1<<20)},
621 {body: randString(16<<10 - 1)},
622 {body: randString(16 << 10)},
623 {body: randString(16<<10 + 1)},
624 {body: randString(512<<10 - 1)},
625 {body: randString(512 << 10)},
626 {body: randString(512<<10 + 1)},
627 {body: randString(1<<20 - 1)},
628 {body: randString(1 << 20)},
629 {body: randString(1<<20 + 2)},
630 }
631
632 type reqInfo struct {
633 req *http.Request
634 slurp []byte
635 err error
636 }
637 gotc := make(chan reqInfo, 1)
638 ts := newTestServer(t,
639 func(w http.ResponseWriter, r *http.Request) {
640 slurp, err := io.ReadAll(r.Body)
641 if err != nil {
642 gotc <- reqInfo{err: err}
643 } else {
644 gotc <- reqInfo{req: r, slurp: slurp}
645 }
646 },
647 )
648
649 for i, tt := range bodyTests {
650 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
651 defer tr.CloseIdleConnections()
652
653 var body io.Reader = strings.NewReader(tt.body)
654 if tt.noContentLen {
655 body = struct{ io.Reader }{body}
656 }
657 req, err := http.NewRequest("POST", ts.URL, body)
658 if err != nil {
659 t.Fatalf("#%d: %v", i, err)
660 }
661 c := &http.Client{Transport: tr}
662 res, err := c.Do(req)
663 if err != nil {
664 t.Fatalf("#%d: %v", i, err)
665 }
666 defer res.Body.Close()
667 ri := <-gotc
668 if ri.err != nil {
669 t.Errorf("#%d: read error: %v", i, ri.err)
670 continue
671 }
672 if got := string(ri.slurp); got != tt.body {
673 t.Errorf("#%d: Read body mismatch.\n got: %q (len %d)\nwant: %q (len %d)", i, shortString(got), len(got), shortString(tt.body), len(tt.body))
674 }
675 wantLen := int64(len(tt.body))
676 if tt.noContentLen && tt.body != "" {
677 wantLen = -1
678 }
679 if ri.req.ContentLength != wantLen {
680 t.Errorf("#%d. handler got ContentLength = %v; want %v", i, ri.req.ContentLength, wantLen)
681 }
682 }
683 }
684
// shortString abbreviates v for test failure messages. Strings of at most
// 100 bytes are returned unchanged. Longer ones keep their first and last 50
// bytes with a note of how many bytes were dropped in between.
func shortString(v string) string {
	const maxLen = 100
	if len(v) > maxLen {
		const keep = maxLen / 2
		omitted := len(v) - maxLen
		return fmt.Sprintf("%v[...%d bytes omitted...]%v", v[:keep], omitted, v[len(v)-keep:])
	}
	return v
}
692
693 func TestTransportDialTLS(t *testing.T) {
694 var mu sync.Mutex
695 var gotReq, didDial bool
696
697 ts := newTestServer(t,
698 func(w http.ResponseWriter, r *http.Request) {
699 mu.Lock()
700 gotReq = true
701 mu.Unlock()
702 },
703 )
704 tr := &Transport{
705 DialTLS: func(netw, addr string, cfg *tls.Config) (net.Conn, error) {
706 mu.Lock()
707 didDial = true
708 mu.Unlock()
709 cfg.InsecureSkipVerify = true
710 c, err := tls.Dial(netw, addr, cfg)
711 if err != nil {
712 return nil, err
713 }
714 return c, c.Handshake()
715 },
716 }
717 defer tr.CloseIdleConnections()
718 client := &http.Client{Transport: tr}
719 res, err := client.Get(ts.URL)
720 if err != nil {
721 t.Fatal(err)
722 }
723 res.Body.Close()
724 mu.Lock()
725 if !gotReq {
726 t.Error("didn't get request")
727 }
728 if !didDial {
729 t.Error("didn't use dial hook")
730 }
731 }
732
733 func TestConfigureTransport(t *testing.T) {
734 t1 := &http.Transport{}
735 err := ConfigureTransport(t1)
736 if err != nil {
737 t.Fatal(err)
738 }
739 if got := fmt.Sprintf("%#v", t1); !strings.Contains(got, `"h2"`) {
740
741 t.Errorf("stringification of HTTP/1 transport didn't contain \"h2\": %v", got)
742 }
743 wantNextProtos := []string{"h2", "http/1.1"}
744 if t1.TLSClientConfig == nil {
745 t.Errorf("nil t1.TLSClientConfig")
746 } else if !reflect.DeepEqual(t1.TLSClientConfig.NextProtos, wantNextProtos) {
747 t.Errorf("TLSClientConfig.NextProtos = %q; want %q", t1.TLSClientConfig.NextProtos, wantNextProtos)
748 }
749 if err := ConfigureTransport(t1); err == nil {
750 t.Error("unexpected success on second call to ConfigureTransport")
751 }
752
753
754 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
755 io.WriteString(w, r.Proto)
756 })
757
758 t1.TLSClientConfig.InsecureSkipVerify = true
759 c := &http.Client{Transport: t1}
760 res, err := c.Get(ts.URL)
761 if err != nil {
762 t.Fatal(err)
763 }
764 slurp, err := io.ReadAll(res.Body)
765 if err != nil {
766 t.Fatal(err)
767 }
768 if got, want := string(slurp), "HTTP/2.0"; got != want {
769 t.Errorf("body = %q; want %q", got, want)
770 }
771 }
772
// capitalizeReader wraps r and upper-cases the ASCII letters a-z in
// everything read through it. All other bytes pass through unchanged.
type capitalizeReader struct {
	r io.Reader
}

// Read reads from the wrapped reader, then upper-cases in place the n bytes
// actually read. The wrapped reader's n and err are returned as-is.
func (cr capitalizeReader) Read(p []byte) (n int, err error) {
	n, err = cr.r.Read(p)
	chunk := p[:n]
	for i := range chunk {
		if c := chunk[i]; 'a' <= c && c <= 'z' {
			chunk[i] = c - 'a' + 'A'
		}
	}
	return n, err
}
786
// flushWriter forwards writes to w. When w also implements http.Flusher, it
// flushes after every Write.
type flushWriter struct {
	w io.Writer
}

// Write writes p to the wrapped writer and then flushes it if it can. The
// flush happens even if the write returned an error.
func (fw flushWriter) Write(p []byte) (n int, err error) {
	n, err = fw.w.Write(p)
	flusher, canFlush := fw.w.(http.Flusher)
	if canFlush {
		flusher.Flush()
	}
	return n, err
}
798
// newLocalListener returns a TCP listener on an ephemeral loopback port. It
// tries IPv4 first and falls back to IPv6. The test fails if neither can be
// bound.
func newLocalListener(t *testing.T) net.Listener {
	if ln, err := net.Listen("tcp4", "127.0.0.1:0"); err == nil {
		return ln
	}
	ln, err := net.Listen("tcp6", "[::1]:0")
	if err != nil {
		t.Fatal(err)
	}
	return ln
}
810
811 func TestTransportReqBodyAfterResponse_200(t *testing.T) { testTransportReqBodyAfterResponse(t, 200) }
812 func TestTransportReqBodyAfterResponse_403(t *testing.T) { testTransportReqBodyAfterResponse(t, 403) }
813
// testTransportReqBodyAfterResponse sends a PUT whose body is only half
// written when the server responds with the given status and END_STREAM.
// The rest of the body is then supplied. For status 200 the test expects the
// remaining body to be sent, ending the stream. For any other status it
// expects the client to send RST_STREAM instead. In both cases the response
// body must read as empty.
func testTransportReqBodyAfterResponse(t *testing.T, status int) {
	const bodySize = 1 << 10

	tc := newTestClientConn(t)
	tc.greet()

	body := tc.newRequestBody()
	body.writeBytes(bodySize / 2)
	req, _ := http.NewRequest("PUT", "https://dummy.tld/", body)
	rt := tc.roundTrip(req)

	tc.wantHeaders(wantHeader{
		streamID:  rt.streamID(),
		endStream: false,
		header: http.Header{
			":authority": []string{"dummy.tld"},
			":method":    []string{"PUT"},
			":path":      []string{"/"},
		},
	})

	// Grant connection- and stream-level flow-control window for the whole
	// request body.
	tc.writeWindowUpdate(0, bodySize)
	tc.writeWindowUpdate(rt.streamID(), bodySize)

	// Only the first half of the body is available so far.
	tc.wantData(wantData{
		streamID:  rt.streamID(),
		endStream: false,
		size:      bodySize / 2,
	})

	// The server responds (and ends its side) before the body is complete.
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", strconv.Itoa(status),
		),
	})

	res := rt.response()
	if res.StatusCode != status {
		t.Fatalf("status code = %v; want %v", res.StatusCode, status)
	}

	// Now supply the second half of the body and signal EOF.
	body.writeBytes(bodySize / 2)
	body.closeWithError(io.EOF)

	if status == 200 {
		// Expect the rest of the body, ending the stream. multiple presumably
		// allows it to arrive in several DATA frames.
		tc.wantData(wantData{
			streamID:  rt.streamID(),
			endStream: true,
			size:      bodySize / 2,
			multiple:  true,
		})
	} else {
		// Expect the client to abandon the upload and reset the stream.
		tc.wantFrameType(FrameRSTStream)
	}

	rt.wantBody(nil)
}
877
878
879 func TestTransportFullDuplex(t *testing.T) {
880 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
881 w.WriteHeader(200)
882 w.(http.Flusher).Flush()
883 io.Copy(flushWriter{w}, capitalizeReader{r.Body})
884 fmt.Fprintf(w, "bye.\n")
885 })
886
887 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
888 defer tr.CloseIdleConnections()
889 c := &http.Client{Transport: tr}
890
891 pr, pw := io.Pipe()
892 req, err := http.NewRequest("PUT", ts.URL, io.NopCloser(pr))
893 if err != nil {
894 t.Fatal(err)
895 }
896 req.ContentLength = -1
897 res, err := c.Do(req)
898 if err != nil {
899 t.Fatal(err)
900 }
901 defer res.Body.Close()
902 if res.StatusCode != 200 {
903 t.Fatalf("StatusCode = %v; want %v", res.StatusCode, 200)
904 }
905 bs := bufio.NewScanner(res.Body)
906 want := func(v string) {
907 if !bs.Scan() {
908 t.Fatalf("wanted to read %q but Scan() = false, err = %v", v, bs.Err())
909 }
910 }
911 write := func(v string) {
912 _, err := io.WriteString(pw, v)
913 if err != nil {
914 t.Fatalf("pipe write: %v", err)
915 }
916 }
917 write("foo\n")
918 want("FOO")
919 write("bar\n")
920 want("BAR")
921 pw.Close()
922 want("bye.")
923 if err := bs.Err(); err != nil {
924 t.Fatal(err)
925 }
926 }
927
928 func TestTransportConnectRequest(t *testing.T) {
929 gotc := make(chan *http.Request, 1)
930 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
931 gotc <- r
932 })
933
934 u, err := url.Parse(ts.URL)
935 if err != nil {
936 t.Fatal(err)
937 }
938
939 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
940 defer tr.CloseIdleConnections()
941 c := &http.Client{Transport: tr}
942
943 tests := []struct {
944 req *http.Request
945 want string
946 }{
947 {
948 req: &http.Request{
949 Method: "CONNECT",
950 Header: http.Header{},
951 URL: u,
952 },
953 want: u.Host,
954 },
955 {
956 req: &http.Request{
957 Method: "CONNECT",
958 Header: http.Header{},
959 URL: u,
960 Host: "example.com:123",
961 },
962 want: "example.com:123",
963 },
964 }
965
966 for i, tt := range tests {
967 res, err := c.Do(tt.req)
968 if err != nil {
969 t.Errorf("%d. RoundTrip = %v", i, err)
970 continue
971 }
972 res.Body.Close()
973 req := <-gotc
974 if req.Method != "CONNECT" {
975 t.Errorf("method = %q; want CONNECT", req.Method)
976 }
977 if req.Host != tt.want {
978 t.Errorf("Host = %q; want %q", req.Host, tt.want)
979 }
980 if req.URL.Host != tt.want {
981 t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
982 }
983 }
984 }
985
// headerType selects how a header block is written in the response-pattern
// tests (see writeHeadersMode for the exact framing).
type headerType int

const (
	noHeader    headerType = 0 // the header block is not sent at all
	oneHeader   headerType = 1 // the header block is sent in a single frame
	splitHeader headerType = 2 // the header block is presumably split across frames
)

// Short aliases that keep the TestTransportResPattern_* argument lists
// compact. f0/f1/f2 are headerTypes; d0/d1 are the withData flag (no DATA
// frame / a DATA frame).
const (
	f0 = noHeader
	f1 = oneHeader
	f2 = splitHeader
	d0 = false
	d1 = true
)
1001
1002
1003
1004
1005
1006
1007 func TestTransportResPattern_c0h1d0t0(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f0) }
1008 func TestTransportResPattern_c0h1d0t1(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f1) }
1009 func TestTransportResPattern_c0h1d0t2(t *testing.T) { testTransportResPattern(t, f0, f1, d0, f2) }
1010 func TestTransportResPattern_c0h1d1t0(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f0) }
1011 func TestTransportResPattern_c0h1d1t1(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f1) }
1012 func TestTransportResPattern_c0h1d1t2(t *testing.T) { testTransportResPattern(t, f0, f1, d1, f2) }
1013 func TestTransportResPattern_c0h2d0t0(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f0) }
1014 func TestTransportResPattern_c0h2d0t1(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f1) }
1015 func TestTransportResPattern_c0h2d0t2(t *testing.T) { testTransportResPattern(t, f0, f2, d0, f2) }
1016 func TestTransportResPattern_c0h2d1t0(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f0) }
1017 func TestTransportResPattern_c0h2d1t1(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f1) }
1018 func TestTransportResPattern_c0h2d1t2(t *testing.T) { testTransportResPattern(t, f0, f2, d1, f2) }
1019 func TestTransportResPattern_c1h1d0t0(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f0) }
1020 func TestTransportResPattern_c1h1d0t1(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f1) }
1021 func TestTransportResPattern_c1h1d0t2(t *testing.T) { testTransportResPattern(t, f1, f1, d0, f2) }
1022 func TestTransportResPattern_c1h1d1t0(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f0) }
1023 func TestTransportResPattern_c1h1d1t1(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f1) }
1024 func TestTransportResPattern_c1h1d1t2(t *testing.T) { testTransportResPattern(t, f1, f1, d1, f2) }
1025 func TestTransportResPattern_c1h2d0t0(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f0) }
1026 func TestTransportResPattern_c1h2d0t1(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f1) }
1027 func TestTransportResPattern_c1h2d0t2(t *testing.T) { testTransportResPattern(t, f1, f2, d0, f2) }
1028 func TestTransportResPattern_c1h2d1t0(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f0) }
1029 func TestTransportResPattern_c1h2d1t1(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f1) }
1030 func TestTransportResPattern_c1h2d1t2(t *testing.T) { testTransportResPattern(t, f1, f2, d1, f2) }
1031 func TestTransportResPattern_c2h1d0t0(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f0) }
1032 func TestTransportResPattern_c2h1d0t1(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f1) }
1033 func TestTransportResPattern_c2h1d0t2(t *testing.T) { testTransportResPattern(t, f2, f1, d0, f2) }
1034 func TestTransportResPattern_c2h1d1t0(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f0) }
1035 func TestTransportResPattern_c2h1d1t1(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f1) }
1036 func TestTransportResPattern_c2h1d1t2(t *testing.T) { testTransportResPattern(t, f2, f1, d1, f2) }
1037 func TestTransportResPattern_c2h2d0t0(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f0) }
1038 func TestTransportResPattern_c2h2d0t1(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f1) }
1039 func TestTransportResPattern_c2h2d0t2(t *testing.T) { testTransportResPattern(t, f2, f2, d0, f2) }
1040 func TestTransportResPattern_c2h2d1t0(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f0) }
1041 func TestTransportResPattern_c2h2d1t1(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f1) }
1042 func TestTransportResPattern_c2h2d1t2(t *testing.T) { testTransportResPattern(t, f2, f2, d1, f2) }
1043
// testTransportResPattern drives one POST through a test client conn with a
// scripted server reply:
//   - expect100Continue: whether and how a ":status 100" block is sent (the
//     request also carries "Expect: 100-continue" unless it is noHeader);
//   - resHeader: how the final 200 header block is framed;
//   - withData: whether a DATA frame with resBody follows;
//   - trailers: how the "some-trailer" block is framed (declared via a
//     "trailer" response header when present).
//
// It checks the request body, status, body and trailers seen by the client.
// resHeader == noHeader is not a valid combination and panics.
func testTransportResPattern(t *testing.T, expect100Continue, resHeader headerType, withData bool, trailers headerType) {
	const reqBody = "some request body"
	const resBody = "some response body"

	if resHeader == noHeader {
		// A response without a final header block is not a meaningful
		// pattern for this test; reject it outright.
		panic("invalid combination")
	}

	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("POST", "https://dummy.tld/", strings.NewReader(reqBody))
	if expect100Continue != noHeader {
		req.Header.Set("Expect", "100-continue")
	}
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)

	// Send the 100 Continue block in the requested mode (presumably nothing
	// is written for noHeader).
	tc.writeHeadersMode(expect100Continue, HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  false,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "100",
		),
	})

	// The client then sends the whole request body, ending its side.
	tc.wantData(wantData{
		streamID:  rt.streamID(),
		endStream: true,
		size:      len(reqBody),
	})

	hdr := []string{
		":status", "200",
		"x-foo", "blah",
		"x-bar", "more",
	}
	if trailers != noHeader {
		hdr = append(hdr, "trailer", "some-trailer")
	}
	// The final header block ends the stream only when nothing follows it.
	tc.writeHeadersMode(resHeader, HeadersFrameParam{
		StreamID:      rt.streamID(),
		EndHeaders:    true,
		EndStream:     withData == false && trailers == noHeader,
		BlockFragment: tc.makeHeaderBlockFragment(hdr...),
	})
	if withData {
		endStream := trailers == noHeader
		tc.writeData(rt.streamID(), endStream, []byte(resBody))
	}
	tc.writeHeadersMode(trailers, HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			"some-trailer", "some-value",
		),
	})

	rt.wantStatus(200)
	if !withData {
		rt.wantBody(nil)
	} else {
		rt.wantBody([]byte(resBody))
	}
	if trailers == noHeader {
		rt.wantTrailers(nil)
	} else {
		rt.wantTrailers(http.Header{
			"Some-Trailer": {"some-value"},
		})
	}
}
1123
1124
1125 func TestTransportUnknown1xx(t *testing.T) {
1126 var buf bytes.Buffer
1127 defer func() { got1xxFuncForTests = nil }()
1128 got1xxFuncForTests = func(code int, header textproto.MIMEHeader) error {
1129 fmt.Fprintf(&buf, "code=%d header=%v\n", code, header)
1130 return nil
1131 }
1132
1133 tc := newTestClientConn(t)
1134 tc.greet()
1135
1136 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
1137 rt := tc.roundTrip(req)
1138
1139 for i := 110; i <= 114; i++ {
1140 tc.writeHeaders(HeadersFrameParam{
1141 StreamID: rt.streamID(),
1142 EndHeaders: true,
1143 EndStream: false,
1144 BlockFragment: tc.makeHeaderBlockFragment(
1145 ":status", fmt.Sprint(i),
1146 "foo-bar", fmt.Sprint(i),
1147 ),
1148 })
1149 }
1150 tc.writeHeaders(HeadersFrameParam{
1151 StreamID: rt.streamID(),
1152 EndHeaders: true,
1153 EndStream: true,
1154 BlockFragment: tc.makeHeaderBlockFragment(
1155 ":status", "204",
1156 ),
1157 })
1158
1159 res := rt.response()
1160 if res.StatusCode != 204 {
1161 t.Fatalf("status code = %v; want 204", res.StatusCode)
1162 }
1163 want := `code=110 header=map[Foo-Bar:[110]]
1164 code=111 header=map[Foo-Bar:[111]]
1165 code=112 header=map[Foo-Bar:[112]]
1166 code=113 header=map[Foo-Bar:[113]]
1167 code=114 header=map[Foo-Bar:[114]]
1168 `
1169 if got := buf.String(); got != want {
1170 t.Errorf("Got trace:\n%s\nWant:\n%s", got, want)
1171 }
1172 }
1173
// TestTransportReceiveUndeclaredTrailer checks that a trailer the server did
// not announce in a "Trailer" response header is still surfaced in the
// response's trailers once the trailing HEADERS frame ends the stream.
func TestTransportReceiveUndeclaredTrailer(t *testing.T) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	// Response headers with no "trailer" declaration; the stream stays open.
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  false,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
		),
	})
	// A second HEADERS frame carrying the trailer and ending the stream.
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			"some-trailer", "I'm an undeclared Trailer!",
		),
	})

	rt.wantStatus(200)
	rt.wantBody(nil)
	rt.wantTrailers(http.Header{
		"Some-Trailer": []string{"I'm an undeclared Trailer!"},
	})
}
1204
1205 func TestTransportInvalidTrailer_Pseudo1(t *testing.T) {
1206 testTransportInvalidTrailer_Pseudo(t, oneHeader)
1207 }
1208 func TestTransportInvalidTrailer_Pseudo2(t *testing.T) {
1209 testTransportInvalidTrailer_Pseudo(t, splitHeader)
1210 }
1211 func testTransportInvalidTrailer_Pseudo(t *testing.T, trailers headerType) {
1212 testInvalidTrailer(t, trailers, pseudoHeaderError(":colon"),
1213 ":colon", "foo",
1214 "foo", "bar",
1215 )
1216 }
1217
1218 func TestTransportInvalidTrailer_Capital1(t *testing.T) {
1219 testTransportInvalidTrailer_Capital(t, oneHeader)
1220 }
1221 func TestTransportInvalidTrailer_Capital2(t *testing.T) {
1222 testTransportInvalidTrailer_Capital(t, splitHeader)
1223 }
1224 func testTransportInvalidTrailer_Capital(t *testing.T, trailers headerType) {
1225 testInvalidTrailer(t, trailers, headerFieldNameError("Capital"),
1226 "foo", "bar",
1227 "Capital", "bad",
1228 )
1229 }
1230 func TestTransportInvalidTrailer_EmptyFieldName(t *testing.T) {
1231 testInvalidTrailer(t, oneHeader, headerFieldNameError(""),
1232 "", "bad",
1233 )
1234 }
1235 func TestTransportInvalidTrailer_BinaryFieldValue(t *testing.T) {
1236 testInvalidTrailer(t, oneHeader, headerFieldValueError("x"),
1237 "x", "has\nnewline",
1238 )
1239 }
1240
// testInvalidTrailer sends a 200 response that declares a trailer, then a
// trailer block built from the name/value pairs in trailers (written with
// the given headerType mode, ending the stream). It checks that reading the
// response body fails with a StreamError whose Cause equals wantErr and that
// no body bytes are returned.
func testInvalidTrailer(t *testing.T, mode headerType, wantErr error, trailers ...string) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	// Response headers, announcing a trailer; the stream stays open.
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  false,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
			"trailer", "declared",
		),
	})
	// The (invalid) trailer block, which ends the stream.
	tc.writeHeadersMode(mode, HeadersFrameParam{
		StreamID:      rt.streamID(),
		EndHeaders:    true,
		EndStream:     true,
		BlockFragment: tc.makeHeaderBlockFragment(trailers...),
	})

	rt.wantStatus(200)
	body, err := rt.readBody()
	// The error surfaces on the body read, not on RoundTrip itself.
	se, ok := err.(StreamError)
	if !ok || se.Cause != wantErr {
		t.Fatalf("res.Body ReadAll error = %q, %#v; want StreamError with cause %T, %#v", body, err, wantErr, wantErr)
	}
	if len(body) > 0 {
		t.Fatalf("body = %q; want nothing", body)
	}
}
1274
1275
1276
1277
1278
1279 func headerListSize(h http.Header) (size uint32) {
1280 for k, vv := range h {
1281 for _, v := range vv {
1282 hf := hpack.HeaderField{Name: k, Value: v}
1283 size += hf.Size()
1284 }
1285 }
1286 return size
1287 }
1288
1289
1290
1291
1292
1293
1294
1295
1296 func padHeaders(t *testing.T, h http.Header, limit uint64, filler string) {
1297 if limit > 0xffffffff {
1298 t.Fatalf("padHeaders: refusing to pad to more than 2^32-1 bytes. limit = %v", limit)
1299 }
1300 hf := hpack.HeaderField{Name: "Pad-Headers", Value: ""}
1301 minPadding := uint64(hf.Size())
1302 size := uint64(headerListSize(h))
1303
1304 minlimit := size + minPadding
1305 if limit < minlimit {
1306 t.Fatalf("padHeaders: limit %v < %v", limit, minlimit)
1307 }
1308
1309
1310
1311 nameFmt := "Pad-Headers-%06d"
1312 hf = hpack.HeaderField{Name: fmt.Sprintf(nameFmt, 1), Value: filler}
1313 fieldSize := uint64(hf.Size())
1314
1315
1316
1317 limit = limit - minPadding
1318 for i := 0; size+fieldSize < limit; i++ {
1319 name := fmt.Sprintf(nameFmt, i)
1320 h.Add(name, filler)
1321 size += fieldSize
1322 }
1323
1324
1325 remain := limit - size
1326 lastValue := strings.Repeat("*", int(remain))
1327 h.Add("Pad-Headers", lastValue)
1328 }
1329
1330 func TestPadHeaders(t *testing.T) {
1331 check := func(h http.Header, limit uint32, fillerLen int) {
1332 if h == nil {
1333 h = make(http.Header)
1334 }
1335 filler := strings.Repeat("f", fillerLen)
1336 padHeaders(t, h, uint64(limit), filler)
1337 gotSize := headerListSize(h)
1338 if gotSize != limit {
1339 t.Errorf("Got size = %v; want %v", gotSize, limit)
1340 }
1341 }
1342
1343 hf := hpack.HeaderField{Name: "Pad-Headers", Value: ""}
1344 minLimit := hf.Size()
1345 for limit := minLimit; limit <= 128; limit++ {
1346 for fillerLen := 0; uint32(fillerLen) <= limit; fillerLen++ {
1347 check(nil, limit, fillerLen)
1348 }
1349 }
1350
1351
1352
1353
1354
1355
1356 tests := []struct {
1357 fillerLen int
1358 limit uint32
1359 }{
1360 {
1361 fillerLen: 64,
1362 limit: 1024,
1363 },
1364 {
1365 fillerLen: 1024,
1366 limit: 1286,
1367 },
1368 {
1369 fillerLen: 256,
1370 limit: 2048,
1371 },
1372 {
1373 fillerLen: 1024,
1374 limit: 10 * 1024,
1375 },
1376 {
1377 fillerLen: 1023,
1378 limit: 11 * 1024,
1379 },
1380 }
1381 h := make(http.Header)
1382 for _, tc := range tests {
1383 check(nil, tc.limit, tc.fillerLen)
1384 check(h, tc.limit, tc.fillerLen)
1385 }
1386 }
1387
1388 func TestTransportChecksRequestHeaderListSize(t *testing.T) {
1389 ts := newTestServer(t,
1390 func(w http.ResponseWriter, r *http.Request) {
1391
1392
1393
1394
1395
1396
1397 io.ReadAll(r.Body)
1398 r.Body.Close()
1399 },
1400 func(ts *httptest.Server) {
1401 ts.Config.MaxHeaderBytes = 16 << 10
1402 },
1403 optQuiet,
1404 )
1405
1406 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
1407 defer tr.CloseIdleConnections()
1408
1409 checkRoundTrip := func(req *http.Request, wantErr error, desc string) {
1410
1411
1412 req0, err := http.NewRequest("GET", ts.URL, nil)
1413 if err != nil {
1414 t.Fatalf("newRequest: NewRequest: %v", err)
1415 }
1416 res0, err := tr.RoundTrip(req0)
1417 if err != nil {
1418 t.Errorf("%v: Initial RoundTrip err = %v", desc, err)
1419 }
1420 res0.Body.Close()
1421
1422 res, err := tr.RoundTrip(req)
1423 if err != wantErr {
1424 if res != nil {
1425 res.Body.Close()
1426 }
1427 t.Errorf("%v: RoundTrip err = %v; want %v", desc, err, wantErr)
1428 return
1429 }
1430 if err == nil {
1431 if res == nil {
1432 t.Errorf("%v: response nil; want non-nil.", desc)
1433 return
1434 }
1435 defer res.Body.Close()
1436 if res.StatusCode != http.StatusOK {
1437 t.Errorf("%v: response status = %v; want %v", desc, res.StatusCode, http.StatusOK)
1438 }
1439 return
1440 }
1441 if res != nil {
1442 t.Errorf("%v: RoundTrip err = %v but response non-nil", desc, err)
1443 }
1444 }
1445 headerListSizeForRequest := func(req *http.Request) (size uint64) {
1446 contentLen := actualContentLength(req)
1447 trailers, err := commaSeparatedTrailers(req)
1448 if err != nil {
1449 t.Fatalf("headerListSizeForRequest: %v", err)
1450 }
1451 cc := &ClientConn{peerMaxHeaderListSize: 0xffffffffffffffff}
1452 cc.henc = hpack.NewEncoder(&cc.hbuf)
1453 cc.mu.Lock()
1454 hdrs, err := cc.encodeHeaders(req, true, trailers, contentLen)
1455 cc.mu.Unlock()
1456 if err != nil {
1457 t.Fatalf("headerListSizeForRequest: %v", err)
1458 }
1459 hpackDec := hpack.NewDecoder(initialHeaderTableSize, func(hf hpack.HeaderField) {
1460 size += uint64(hf.Size())
1461 })
1462 if len(hdrs) > 0 {
1463 if _, err := hpackDec.Write(hdrs); err != nil {
1464 t.Fatalf("headerListSizeForRequest: %v", err)
1465 }
1466 }
1467 return size
1468 }
1469
1470
1471
1472 newRequest := func() *http.Request {
1473
1474 body := strings.NewReader("hello")
1475 req, err := http.NewRequest("POST", ts.URL, body)
1476 if err != nil {
1477 t.Fatalf("newRequest: NewRequest: %v", err)
1478 }
1479 return req
1480 }
1481
1482 var (
1483 scMu sync.Mutex
1484 sc *serverConn
1485 )
1486 testHookGetServerConn = func(v *serverConn) {
1487 scMu.Lock()
1488 defer scMu.Unlock()
1489 if sc != nil {
1490 panic("testHookGetServerConn called multiple times")
1491 }
1492 sc = v
1493 }
1494 defer func() {
1495 testHookGetServerConn = nil
1496 }()
1497
1498
1499 req := newRequest()
1500 checkRoundTrip(req, nil, "Initial request")
1501 addr := authorityAddr(req.URL.Scheme, req.URL.Host)
1502 cc, err := tr.connPool().GetClientConn(req, addr)
1503 if err != nil {
1504 t.Fatalf("GetClientConn: %v", err)
1505 }
1506 cc.mu.Lock()
1507 peerSize := cc.peerMaxHeaderListSize
1508 cc.mu.Unlock()
1509 scMu.Lock()
1510 wantSize := uint64(sc.maxHeaderListSize())
1511 scMu.Unlock()
1512 if peerSize != wantSize {
1513 t.Errorf("peerMaxHeaderListSize = %v; want %v", peerSize, wantSize)
1514 }
1515
1516
1517
1518 wantHeaderBytes := uint64(ts.Config.MaxHeaderBytes) + 320
1519 if peerSize != wantHeaderBytes {
1520 t.Errorf("peerMaxHeaderListSize = %v; want %v.", peerSize, wantHeaderBytes)
1521 }
1522
1523
1524 req = newRequest()
1525 req.Header = make(http.Header)
1526 req.Trailer = make(http.Header)
1527 filler := strings.Repeat("*", 1024)
1528 padHeaders(t, req.Trailer, peerSize, filler)
1529
1530
1531 defaultBytes := headerListSizeForRequest(req)
1532 padHeaders(t, req.Header, peerSize-defaultBytes, filler)
1533 checkRoundTrip(req, nil, "Headers & Trailers under limit")
1534
1535
1536 req = newRequest()
1537 req.Header = make(http.Header)
1538 padHeaders(t, req.Header, peerSize, filler)
1539 checkRoundTrip(req, errRequestHeaderListSize, "Headers over limit")
1540
1541
1542 req = newRequest()
1543 req.Trailer = make(http.Header)
1544 padHeaders(t, req.Trailer, peerSize+1, filler)
1545 checkRoundTrip(req, errRequestHeaderListSize, "Trailers over limit")
1546
1547
1548 req = newRequest()
1549 filler = strings.Repeat("*", int(peerSize))
1550 req.Header = make(http.Header)
1551 req.Header.Set("Big", filler)
1552 checkRoundTrip(req, errRequestHeaderListSize, "Single large header")
1553
1554
1555 req = newRequest()
1556 req.Trailer = make(http.Header)
1557 req.Trailer.Set("Big", filler)
1558 checkRoundTrip(req, errRequestHeaderListSize, "Single large trailer")
1559 }
1560
// TestTransportChecksResponseHeaderListSize checks that a response whose
// decoded header list is far larger than the client accepts (over 10MB of
// repeated 1KiB name/value pairs) fails RoundTrip with
// errResponseHeaderListSize, possibly wrapped in a StreamError.
func TestTransportChecksResponseHeaderListSize(t *testing.T) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)

	// 5042 copies of the same 1KiB name with the same 1KiB value.
	hdr := []string{":status", "200"}
	large := strings.Repeat("a", 1<<10)
	for i := 0; i < 5042; i++ {
		hdr = append(hdr, large, large)
	}
	hbf := tc.makeHeaderBlockFragment(hdr...)
	// The repeated pair compresses to a small fragment; pin its size so a
	// change in the encoder's behavior is noticed.
	// NOTE(review): presumably via HPACK dynamic-table indexing of the
	// duplicate pair; confirm against hpack if this constant changes.
	if size, want := len(hbf), 6329; size != want {
		t.Fatalf("encoding over 10MB of duplicate keypairs took %d bytes; expected %d", size, want)
	}
	tc.writeHeaders(HeadersFrameParam{
		StreamID:      rt.streamID(),
		EndHeaders:    true,
		EndStream:     true,
		BlockFragment: hbf,
	})

	res, err := rt.result()
	if e, ok := err.(StreamError); ok {
		err = e.Cause
	}
	if err != errResponseHeaderListSize {
		// For the failure message, estimate how much header data got
		// through (name + value + 32 per field).
		size := int64(0)
		if res != nil {
			res.Body.Close()
			for k, vv := range res.Header {
				for _, v := range vv {
					size += int64(len(k)) + int64(len(v)) + 32
				}
			}
		}
		t.Fatalf("RoundTrip Error = %v (and %d bytes of response headers); want errResponseHeaderListSize", err, size)
	}
}
1606
// TestTransportCookieHeaderSplit checks that multiple Cookie request header
// values are sent as separate "cookie" fields, one per ';'-separated
// cookie pair, with surrounding spaces and empty pieces dropped (per the
// expected header below).
func TestTransportCookieHeaderSplit(t *testing.T) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	req.Header.Add("Cookie", "a=b;c=d;  e=f;")
	req.Header.Add("Cookie", "e=f;g=h; ")
	req.Header.Add("Cookie", "i=j")
	rt := tc.roundTrip(req)

	tc.wantHeaders(wantHeader{
		streamID:  rt.streamID(),
		endStream: true,
		header: http.Header{
			"cookie": []string{"a=b", "c=d", "e=f", "e=f", "g=h", "i=j"},
		},
	})
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "204",
		),
	})

	if err := rt.err(); err != nil {
		t.Fatalf("RoundTrip = %v, want success", err)
	}
}
1637
1638
1639
1640
// TestTransportBodyReadErrorType checks that when the handler panics after
// the response headers have been flushed, reading the response body yields
// a StreamError for stream 1 with code 0x2 (INTERNAL_ERROR in the HTTP/2
// spec).
func TestTransportBodyReadErrorType(t *testing.T) {
	// doPanic lets the test control when the handler panics: only after the
	// client has received the response headers.
	doPanic := make(chan bool, 1)
	ts := newTestServer(t,
		func(w http.ResponseWriter, r *http.Request) {
			w.(http.Flusher).Flush() // force headers out
			<-doPanic
			panic("boom")
		},
		optQuiet,
	)

	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
	defer tr.CloseIdleConnections()
	c := &http.Client{Transport: tr}

	res, err := c.Get(ts.URL)
	if err != nil {
		t.Fatal(err)
	}
	defer res.Body.Close()
	doPanic <- true
	buf := make([]byte, 100)
	n, err := res.Body.Read(buf)
	got, ok := err.(StreamError)
	want := StreamError{StreamID: 0x1, Code: 0x2}
	if !ok || got.StreamID != want.StreamID || got.Code != want.Code {
		t.Errorf("Read = %v, %#v; want error %#v", n, err, want)
	}
}
1670
1671
1672
1673
1674 func TestTransportDoubleCloseOnWriteError(t *testing.T) {
1675 var (
1676 mu sync.Mutex
1677 conn net.Conn
1678 )
1679
1680 ts := newTestServer(t,
1681 func(w http.ResponseWriter, r *http.Request) {
1682 mu.Lock()
1683 defer mu.Unlock()
1684 if conn != nil {
1685 conn.Close()
1686 }
1687 },
1688 )
1689
1690 tr := &Transport{
1691 TLSClientConfig: tlsConfigInsecure,
1692 DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
1693 tc, err := tls.Dial(network, addr, cfg)
1694 if err != nil {
1695 return nil, err
1696 }
1697 mu.Lock()
1698 defer mu.Unlock()
1699 conn = tc
1700 return tc, nil
1701 },
1702 }
1703 defer tr.CloseIdleConnections()
1704 c := &http.Client{Transport: tr}
1705 c.Get(ts.URL)
1706 }
1707
1708
1709
1710
1711 func TestTransportDisableKeepAlives(t *testing.T) {
1712 ts := newTestServer(t,
1713 func(w http.ResponseWriter, r *http.Request) {
1714 io.WriteString(w, "hi")
1715 },
1716 )
1717
1718 connClosed := make(chan struct{})
1719 tr := &Transport{
1720 t1: &http.Transport{
1721 DisableKeepAlives: true,
1722 },
1723 TLSClientConfig: tlsConfigInsecure,
1724 DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
1725 tc, err := tls.Dial(network, addr, cfg)
1726 if err != nil {
1727 return nil, err
1728 }
1729 return ¬eCloseConn{Conn: tc, closefn: func() { close(connClosed) }}, nil
1730 },
1731 }
1732 c := &http.Client{Transport: tr}
1733 res, err := c.Get(ts.URL)
1734 if err != nil {
1735 t.Fatal(err)
1736 }
1737 if _, err := io.ReadAll(res.Body); err != nil {
1738 t.Fatal(err)
1739 }
1740 defer res.Body.Close()
1741
1742 select {
1743 case <-connClosed:
1744 case <-time.After(1 * time.Second):
1745 t.Errorf("timeout")
1746 }
1747
1748 }
1749
1750
1751
1752 func TestTransportDisableKeepAlives_Concurrency(t *testing.T) {
1753 const D = 25 * time.Millisecond
1754 ts := newTestServer(t,
1755 func(w http.ResponseWriter, r *http.Request) {
1756 time.Sleep(D)
1757 io.WriteString(w, "hi")
1758 },
1759 )
1760
1761 var dials int32
1762 var conns sync.WaitGroup
1763 tr := &Transport{
1764 t1: &http.Transport{
1765 DisableKeepAlives: true,
1766 },
1767 TLSClientConfig: tlsConfigInsecure,
1768 DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
1769 tc, err := tls.Dial(network, addr, cfg)
1770 if err != nil {
1771 return nil, err
1772 }
1773 atomic.AddInt32(&dials, 1)
1774 conns.Add(1)
1775 return ¬eCloseConn{Conn: tc, closefn: func() { conns.Done() }}, nil
1776 },
1777 }
1778 c := &http.Client{Transport: tr}
1779 var reqs sync.WaitGroup
1780 const N = 20
1781 for i := 0; i < N; i++ {
1782 reqs.Add(1)
1783 if i == N-1 {
1784
1785
1786
1787
1788
1789
1790 time.Sleep(D * 2)
1791 }
1792 go func() {
1793 defer reqs.Done()
1794 res, err := c.Get(ts.URL)
1795 if err != nil {
1796 t.Error(err)
1797 return
1798 }
1799 if _, err := io.ReadAll(res.Body); err != nil {
1800 t.Error(err)
1801 return
1802 }
1803 res.Body.Close()
1804 }()
1805 }
1806 reqs.Wait()
1807 conns.Wait()
1808 t.Logf("did %d dials, %d requests", atomic.LoadInt32(&dials), N)
1809 }
1810
// noteCloseConn wraps a net.Conn and runs closefn the first time Close is
// called, letting tests observe when the Transport closes a connection.
type noteCloseConn struct {
	net.Conn
	onceClose sync.Once
	closefn   func() // run at most once, on the first Close
}

// Close runs closefn (only on the first call), then closes the wrapped
// connection and returns its error.
func (nc *noteCloseConn) Close() error {
	notify := nc.closefn
	nc.onceClose.Do(notify)
	inner := nc.Conn
	return inner.Close()
}
1821
// isTimeout reports whether err is a timeout. It unwraps any chain of
// *url.Error values and then reports the Timeout() result of the innermost
// error if that error implements net.Error; otherwise it returns false.
// A nil error is not a timeout.
func isTimeout(err error) bool {
	for err != nil {
		// *url.Error also implements net.Error, so peel it off first.
		if ue, ok := err.(*url.Error); ok {
			err = ue.Err
			continue
		}
		ne, ok := err.(net.Error)
		return ok && ne.Timeout()
	}
	return false
}
1833
1834
// TestTransportResponseHeaderTimeout_NoBody checks ResponseHeaderTimeout
// for a GET with no request body.
func TestTransportResponseHeaderTimeout_NoBody(t *testing.T) {
	testTransportResponseHeaderTimeout(t, false)
}

// TestTransportResponseHeaderTimeout_Body checks ResponseHeaderTimeout for
// a POST whose request body is sent in full before the timeout is measured.
func TestTransportResponseHeaderTimeout_Body(t *testing.T) {
	testTransportResponseHeaderTimeout(t, true)
}

// testTransportResponseHeaderTimeout configures a 5ms ResponseHeaderTimeout,
// sends a request (with a 4MiB body if body is true) and never answers it.
// RoundTrip must still be pending after 4ms of test-clock time and must fail
// with a timeout error after 5ms.
func testTransportResponseHeaderTimeout(t *testing.T, body bool) {
	const bodySize = 4 << 20
	tc := newTestClientConn(t, func(tr *Transport) {
		tr.t1 = &http.Transport{
			ResponseHeaderTimeout: 5 * time.Millisecond,
		}
	})
	tc.greet()

	var req *http.Request
	var reqBody *testRequestBody
	if body {
		reqBody = tc.newRequestBody()
		reqBody.writeBytes(bodySize)
		reqBody.closeWithError(io.EOF)
		req, _ = http.NewRequest("POST", "https://dummy.tld/", reqBody)
		req.Header.Set("Content-Type", "text/foo")
	} else {
		req, _ = http.NewRequest("GET", "https://dummy.tld/", nil)
	}

	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)

	// Open both connection- and stream-level flow-control windows wide
	// enough for the whole request body.
	tc.writeWindowUpdate(0, bodySize)
	tc.writeWindowUpdate(rt.streamID(), bodySize)

	if body {
		tc.wantData(wantData{
			endStream: true,
			size:      bodySize,
			multiple:  true,
		})
	}

	// Just under the timeout: still waiting for response headers.
	tc.advance(4 * time.Millisecond)
	if rt.done() {
		t.Fatalf("RoundTrip is done after 4ms; want still waiting")
	}
	// Reaching the timeout must fail the request.
	tc.advance(1 * time.Millisecond)

	if err := rt.err(); !isTimeout(err) {
		t.Fatalf("RoundTrip error: %v; want timeout error", err)
	}
}
1888
1889 func TestTransportDisableCompression(t *testing.T) {
1890 const body = "sup"
1891 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
1892 want := http.Header{
1893 "User-Agent": []string{"Go-http-client/2.0"},
1894 }
1895 if !reflect.DeepEqual(r.Header, want) {
1896 t.Errorf("request headers = %v; want %v", r.Header, want)
1897 }
1898 })
1899
1900 tr := &Transport{
1901 TLSClientConfig: tlsConfigInsecure,
1902 t1: &http.Transport{
1903 DisableCompression: true,
1904 },
1905 }
1906 defer tr.CloseIdleConnections()
1907
1908 req, err := http.NewRequest("GET", ts.URL, nil)
1909 if err != nil {
1910 t.Fatal(err)
1911 }
1912 res, err := tr.RoundTrip(req)
1913 if err != nil {
1914 t.Fatal(err)
1915 }
1916 defer res.Body.Close()
1917 }
1918
1919
// TestTransportRejectsConnHeaders checks how the Transport treats
// connection-specific request headers. Some must make RoundTrip fail, and
// others must be silently dropped before the request is sent.
func TestTransportRejectsConnHeaders(t *testing.T) {
	// The handler echoes the sorted, comma-joined names of the request
	// headers it received in the Got-Header response header.
	ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
		var got []string
		for k := range r.Header {
			got = append(got, k)
		}
		sort.Strings(got)
		w.Header().Set("Got-Header", strings.Join(got, ","))
	})

	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
	defer tr.CloseIdleConnections()

	// Each case sets one request header. want is either the Got-Header echo
	// from the server or "ERROR: " followed by the RoundTrip error.
	// "Accept-Encoding,User-Agent" means the header was dropped and only the
	// Transport's default headers arrived.
	tests := []struct {
		key   string
		value []string
		want  string
	}{
		{
			key:   "Upgrade",
			value: []string{"anything"},
			want:  "ERROR: http2: invalid Upgrade request header: [\"anything\"]",
		},
		{
			key:   "Connection",
			value: []string{"foo"},
			want:  "ERROR: http2: invalid Connection request header: [\"foo\"]",
		},
		{
			key:   "Connection",
			value: []string{"close"},
			want:  "Accept-Encoding,User-Agent",
		},
		{
			key:   "Connection",
			value: []string{"CLoSe"},
			want:  "Accept-Encoding,User-Agent",
		},
		{
			key:   "Connection",
			value: []string{"close", "something-else"},
			want:  "ERROR: http2: invalid Connection request header: [\"close\" \"something-else\"]",
		},
		{
			key:   "Connection",
			value: []string{"keep-alive"},
			want:  "Accept-Encoding,User-Agent",
		},
		{
			key:   "Connection",
			value: []string{"Keep-ALIVE"},
			want:  "Accept-Encoding,User-Agent",
		},
		{
			key:   "Proxy-Connection", // just deleted and ignored
			value: []string{"keep-alive"},
			want:  "Accept-Encoding,User-Agent",
		},
		{
			key:   "Transfer-Encoding",
			value: []string{""},
			want:  "Accept-Encoding,User-Agent",
		},
		{
			key:   "Transfer-Encoding",
			value: []string{"foo"},
			want:  "ERROR: http2: invalid Transfer-Encoding request header: [\"foo\"]",
		},
		{
			key:   "Transfer-Encoding",
			value: []string{"chunked"},
			want:  "Accept-Encoding,User-Agent",
		},
		{
			// Unlike Connection values, "chunked" is matched case-sensitively.
			key:   "Transfer-Encoding",
			value: []string{"chunKed"},
			want:  "ERROR: http2: invalid Transfer-Encoding request header: [\"chunKed\"]",
		},
		{
			key:   "Transfer-Encoding",
			value: []string{"chunked", "other"},
			want:  "ERROR: http2: invalid Transfer-Encoding request header: [\"chunked\" \"other\"]",
		},
		{
			key:   "Content-Length",
			value: []string{"123"},
			want:  "Accept-Encoding,User-Agent",
		},
		{
			key:   "Keep-Alive",
			value: []string{"doop"},
			want:  "Accept-Encoding,User-Agent",
		},
	}

	for _, tt := range tests {
		req, _ := http.NewRequest("GET", ts.URL, nil)
		req.Header[tt.key] = tt.value
		res, err := tr.RoundTrip(req)
		var got string
		if err != nil {
			got = fmt.Sprintf("ERROR: %v", err)
		} else {
			got = res.Header.Get("Got-Header")
			res.Body.Close()
		}
		if got != tt.want {
			t.Errorf("For key %q, value %q, got = %q; want %q", tt.key, tt.value, got, tt.want)
		}
	}
}
2031
2032
2033
// TestTransportRejectsContentLengthWithSign checks which response
// Content-Length values survive into the client's res.Header for a HEAD
// request. A plain non-negative integer that fits in int64 is kept.
// Signed values (+3, -3) and values that overflow int64 must not appear:
// wantCL "" means the header is absent.
func TestTransportRejectsContentLengthWithSign(t *testing.T) {
	tests := []struct {
		name   string
		cl     []string // only cl[0] is used
		wantCL string
	}{
		{
			name:   "proper content-length",
			cl:     []string{"3"},
			wantCL: "3",
		},
		{
			name:   "ignore cl with plus sign",
			cl:     []string{"+3"},
			wantCL: "",
		},
		{
			name:   "ignore cl with minus sign",
			cl:     []string{"-3"},
			wantCL: "",
		},
		{
			name:   "max int64, for safe uint64->int64 conversion",
			cl:     []string{"9223372036854775807"},
			wantCL: "9223372036854775807",
		},
		{
			name:   "overflows int64, so ignored",
			cl:     []string{"9223372036854775808"},
			wantCL: "",
		},
	}

	for _, tt := range tests {
		tt := tt // capture per-iteration value for the subtest closure
		t.Run(tt.name, func(t *testing.T) {
			ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
				w.Header().Set("Content-Length", tt.cl[0])
			})
			tr := &Transport{TLSClientConfig: tlsConfigInsecure}
			defer tr.CloseIdleConnections()

			// HEAD, so the handler's lack of a body matches any declared length.
			req, _ := http.NewRequest("HEAD", ts.URL, nil)
			res, err := tr.RoundTrip(req)

			var got string
			if err != nil {
				got = fmt.Sprintf("ERROR: %v", err)
			} else {
				got = res.Header.Get("Content-Length")
				res.Body.Close()
			}

			if got != tt.wantCL {
				t.Fatalf("Got: %q\nWant: %q", got, tt.wantCL)
			}
		})
	}
}
2093
2094
2095
// TestTransportFailsOnInvalidHeadersAndTrailers checks that RoundTrip
// rejects requests with invalid header/trailer names or values, such as
// spaces or non-ASCII in names or control bytes in values. It also checks
// that valid requests, including non-ASCII header values, still succeed.
func TestTransportFailsOnInvalidHeadersAndTrailers(t *testing.T) {
	// The handler echoes the sorted request header names in Got-Header,
	// which is only logged when a case unexpectedly succeeds.
	ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
		var got []string
		for k := range r.Header {
			got = append(got, k)
		}
		sort.Strings(got)
		w.Header().Set("Got-Header", strings.Join(got, ","))
	})

	// h is the request header, t the request trailer. wantErr is a substring
	// of the expected error, or "" when RoundTrip should succeed.
	tests := [...]struct {
		h       http.Header
		t       http.Header
		wantErr string
	}{
		0: {
			h:       http.Header{"with space": {"foo"}},
			wantErr: `invalid HTTP header name "with space"`,
		},
		1: {
			h:       http.Header{"name": {"Брэд"}},
			wantErr: "", // okay
		},
		2: {
			h:       http.Header{"имя": {"Brad"}},
			wantErr: `invalid HTTP header name "имя"`,
		},
		3: {
			h:       http.Header{"foo": {"foo\x01bar"}},
			wantErr: `invalid HTTP header value for header "foo"`,
		},
		4: {
			t:       http.Header{"foo": {"foo\x01bar"}},
			wantErr: `invalid HTTP trailer value for header "foo"`,
		},
		5: {
			t:       http.Header{"x-\r\nda": {"foo\x01bar"}},
			wantErr: `invalid HTTP trailer name "x-\r\nda"`,
		},
	}

	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
	defer tr.CloseIdleConnections()

	for i, tt := range tests {
		req, _ := http.NewRequest("GET", ts.URL, nil)
		req.Header = tt.h
		req.Trailer = tt.t
		res, err := tr.RoundTrip(req)
		var bad bool
		if tt.wantErr == "" {
			if err != nil {
				bad = true
				t.Errorf("case %d: error = %v; want no error", i, err)
			}
		} else {
			if !strings.Contains(fmt.Sprint(err), tt.wantErr) {
				bad = true
				t.Errorf("case %d: error = %v; want error %q", i, err, tt.wantErr)
			}
		}
		if err == nil {
			if bad {
				t.Logf("case %d: server got headers %q", i, res.Header.Get("Got-Header"))
			}
			res.Body.Close()
		}
	}
}
2165
2166
2167
2168 func TestGzipReader_DoubleReadCrash(t *testing.T) {
2169 gz := &gzipReader{
2170 body: io.NopCloser(strings.NewReader("0123456789")),
2171 }
2172 var buf [1]byte
2173 n, err1 := gz.Read(buf[:])
2174 if n != 0 || !strings.Contains(fmt.Sprint(err1), "invalid header") {
2175 t.Fatalf("Read = %v, %v; want 0, invalid header", n, err1)
2176 }
2177 n, err2 := gz.Read(buf[:])
2178 if n != 0 || err2 != err1 {
2179 t.Fatalf("second Read = %v, %v; want 0, %v", n, err2, err1)
2180 }
2181 }
2182
2183 func TestGzipReader_ReadAfterClose(t *testing.T) {
2184 body := bytes.Buffer{}
2185 w := gzip.NewWriter(&body)
2186 w.Write([]byte("012345679"))
2187 w.Close()
2188 gz := &gzipReader{
2189 body: io.NopCloser(&body),
2190 }
2191 var buf [1]byte
2192 n, err := gz.Read(buf[:])
2193 if n != 1 || err != nil {
2194 t.Fatalf("first Read = %v, %v; want 1, nil", n, err)
2195 }
2196 if err := gz.Close(); err != nil {
2197 t.Fatalf("gz Close error: %v", err)
2198 }
2199 n, err = gz.Read(buf[:])
2200 if n != 0 || err != fs.ErrClosed {
2201 t.Fatalf("Read after close = %v, %v; want 0, fs.ErrClosed", n, err)
2202 }
2203 }
2204
// TestTransportNewTLSConfig checks how Transport.newTLSConfig derives the
// per-host tls.Config. ServerName comes from the host unless one is already
// set. NextProtoTLS is put first in NextProtos unless it is already present.
func TestTransportNewTLSConfig(t *testing.T) {
	tests := [...]struct {
		conf *tls.Config
		host string
		want *tls.Config
	}{
		// No config: ServerName from host, NextProtos just NextProtoTLS.
		0: {
			conf: nil,
			host: "foo.com",
			want: &tls.Config{
				ServerName: "foo.com",
				NextProtos: []string{NextProtoTLS},
			},
		},

		// An explicit ServerName is kept instead of the host.
		1: {
			conf: &tls.Config{
				ServerName: "bar.com",
			},
			host: "foo.com",
			want: &tls.Config{
				ServerName: "bar.com",
				NextProtos: []string{NextProtoTLS},
			},
		},

		// NextProtoTLS is prepended to existing NextProtos that lack it.
		2: {
			conf: &tls.Config{
				NextProtos: []string{"foo", "bar"},
			},
			host: "example.com",
			want: &tls.Config{
				ServerName: "example.com",
				NextProtos: []string{NextProtoTLS, "foo", "bar"},
			},
		},

		// NextProtos that already contain NextProtoTLS are left in order.
		3: {
			conf: &tls.Config{
				NextProtos: []string{"foo", "bar", NextProtoTLS},
			},
			host: "example.com",
			want: &tls.Config{
				ServerName: "example.com",
				NextProtos: []string{"foo", "bar", NextProtoTLS},
			},
		},
	}
	for i, tt := range tests {
		// Set a non-default field on the input config so it differs from
		// the zero value...
		if tt.conf != nil {
			tt.conf.SessionTicketsDisabled = true
		}

		tr := &Transport{TLSClientConfig: tt.conf}
		got := tr.newTLSConfig(tt.host)

		// ...then reset it on the result so DeepEqual against want ignores it.
		got.SessionTicketsDisabled = false

		if !reflect.DeepEqual(got, tt.want) {
			t.Errorf("%d. got %#v; want %#v", i, got, tt.want)
		}
	}
}
2274
2275
2276
2277
// TestTransportReadHeadResponse checks that a HEAD response keeps the
// server's declared Content-Length (123) in res.ContentLength even though no
// body bytes follow, and that the body reads as empty.
func TestTransportReadHeadResponse(t *testing.T) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  false,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
			"content-length", "123",
		),
	})
	// End the stream with an empty DATA frame.
	tc.writeData(rt.streamID(), true, nil)

	res := rt.response()
	if res.ContentLength != 123 {
		t.Fatalf("Content-Length = %d; want 123", res.ContentLength)
	}
	rt.wantBody(nil)
}
2303
2304 func TestTransportReadHeadResponseWithBody(t *testing.T) {
2305
2306
2307 log.SetOutput(io.Discard)
2308 defer log.SetOutput(os.Stderr)
2309
2310 response := "redirecting to /elsewhere"
2311 tc := newTestClientConn(t)
2312 tc.greet()
2313
2314 req, _ := http.NewRequest("HEAD", "https://dummy.tld/", nil)
2315 rt := tc.roundTrip(req)
2316
2317 tc.wantFrameType(FrameHeaders)
2318 tc.writeHeaders(HeadersFrameParam{
2319 StreamID: rt.streamID(),
2320 EndHeaders: true,
2321 EndStream: false,
2322 BlockFragment: tc.makeHeaderBlockFragment(
2323 ":status", "200",
2324 "content-length", strconv.Itoa(len(response)),
2325 ),
2326 })
2327 tc.writeData(rt.streamID(), true, []byte(response))
2328
2329 res := rt.response()
2330 if res.ContentLength != int64(len(response)) {
2331 t.Fatalf("Content-Length = %d; want %d", res.ContentLength, len(response))
2332 }
2333 rt.wantBody(nil)
2334 }
2335
// neverEnding is an io.Reader that fills every buffer it is given with its
// own byte value and never reports an error or EOF.
type neverEnding byte

// Read fills all of p with byte(b) and reports len(p) bytes read.
func (b neverEnding) Read(p []byte) (int, error) {
	if n := len(p); n > 0 {
		// Seed one byte, then double the filled prefix with copy.
		p[0] = byte(b)
		for filled := 1; filled < n; filled *= 2 {
			copy(p[filled:], p[:filled])
		}
	}
	return len(p), nil
}
2344
2345
2346
2347
2348
// TestTransportHandlerBodyClose sends numReq POSTs, each uploading a 10 MiB
// body, to a handler that closes the request body immediately and replies
// with a 10 MiB response. Every response must be read in full. After the
// idle connections are closed, the goroutine count must drop back to within
// numReq/2 of its starting value.
func TestTransportHandlerBodyClose(t *testing.T) {
	const bodySize = 10 << 20
	ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
		r.Body.Close()
		io.Copy(w, io.LimitReader(neverEnding('A'), bodySize))
	})

	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
	defer tr.CloseIdleConnections()

	// Baseline goroutine count for the leak check at the end.
	g0 := runtime.NumGoroutine()

	const numReq = 10
	for i := 0; i < numReq; i++ {
		// Wrapping in an anonymous struct hides the reader's concrete type.
		// NOTE(review): presumably so NewRequest can't infer a
		// ContentLength; confirm.
		req, err := http.NewRequest("POST", ts.URL, struct{ io.Reader }{io.LimitReader(neverEnding('A'), bodySize)})
		if err != nil {
			t.Fatal(err)
		}
		res, err := tr.RoundTrip(req)
		if err != nil {
			t.Fatal(err)
		}
		n, err := io.Copy(io.Discard, res.Body)
		res.Body.Close()
		if n != bodySize || err != nil {
			t.Fatalf("req#%d: Copy = %d, %v; want %d, nil", i, n, err, bodySize)
		}
	}
	tr.CloseIdleConnections()

	// Poll for up to 5s until the extra goroutines have exited.
	if !waitCondition(5*time.Second, 100*time.Millisecond, func() bool {
		gd := runtime.NumGoroutine() - g0
		return gd < numReq/2
	}) {
		t.Errorf("appeared to leak goroutines")
	}
}
2386
2387
// TestTransportFlowControl has the server stream total bytes (100 MiB, or
// 10 MiB in -short mode) in flushed 64 KiB writes while the client reads
// slowly. At every read, the gap between bytes written by the server and
// bytes received by the client must stay within transportDefaultStreamFlow,
// which shows the client's flow-control window is limiting the server.
func TestTransportFlowControl(t *testing.T) {
	const bufLen = 64 << 10
	var total int64 = 100 << 20
	if testing.Short() {
		total = 10 << 20
	}

	// wrote is updated by the handler with atomic adds and read by the test
	// goroutine with atomic loads.
	var wrote int64
	ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
		b := make([]byte, bufLen)
		for wrote < total {
			n, err := w.Write(b)
			atomic.AddInt64(&wrote, int64(n))
			if err != nil {
				t.Errorf("ResponseWriter.Write error: %v", err)
				break
			}
			w.(http.Flusher).Flush()
		}
	})

	tr := &Transport{TLSClientConfig: tlsConfigInsecure}
	defer tr.CloseIdleConnections()
	req, err := http.NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal("NewRequest error:", err)
	}
	resp, err := tr.RoundTrip(req)
	if err != nil {
		t.Fatal("RoundTrip error:", err)
	}
	defer resp.Body.Close()

	var read int64
	b := make([]byte, bufLen)
	for {
		n, err := resp.Body.Read(b)
		if err == io.EOF {
			break
		}
		if err != nil {
			t.Fatal("Read error:", err)
		}
		read += int64(n)

		const max = transportDefaultStreamFlow
		if w := atomic.LoadInt64(&wrote); -max > read-w || read-w > max {
			t.Fatalf("Too much data inflight: server wrote %v bytes but client only received %v", w, read)
		}

		// Read slowly so the server fills the flow-control window.
		time.Sleep(1 * time.Millisecond)
	}
}
2442
2443
2444
2445
2446
2447
2448 func TestTransportUsesGoAwayDebugError_RoundTrip(t *testing.T) {
2449 testTransportUsesGoAwayDebugError(t, false)
2450 }
2451
2452 func TestTransportUsesGoAwayDebugError_Body(t *testing.T) {
2453 testTransportUsesGoAwayDebugError(t, true)
2454 }
2455
// testTransportUsesGoAwayDebugError checks the error the client reports when
// the server sends two GOAWAY frames and then calls closeWrite on the test
// connection.
//
// The first GOAWAY (ErrCodeNo) carries debug data. The second
// (ErrCodeHTTP11Required) carries none. The expected error is a GoAwayError
// with LastStreamID 5, the second frame's error code and the first frame's
// debug data.
//
// If failMidBody is false, no response headers are sent and RoundTrip is
// expected to return that error. If failMidBody is true, response headers
// announcing a 123-byte body are sent first. RoundTrip must then succeed, and
// the first Response.Body.Read is expected to return the error.
func testTransportUsesGoAwayDebugError(t *testing.T, failMidBody bool) {
	tc := newTestClientConn(t)
	tc.greet()

	const goAwayErrCode = ErrCodeHTTP11Required // arbitrary non-NO_ERROR code
	const goAwayDebugData = "some debug data"

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)

	if failMidBody {
		// Start a response whose body never arrives.
		tc.writeHeaders(HeadersFrameParam{
			StreamID:   rt.streamID(),
			EndHeaders: true,
			EndStream:  false,
			BlockFragment: tc.makeHeaderBlockFragment(
				":status", "200",
				"content-length", "123",
			),
		})
	}

	// The first GOAWAY supplies the debug data and the second supplies the
	// error code. The reported error is expected to combine the two.
	tc.writeGoAway(5, ErrCodeNo, []byte(goAwayDebugData))
	tc.writeGoAway(5, goAwayErrCode, nil)
	tc.closeWrite()

	res, err := rt.result()
	whence := "RoundTrip"
	if failMidBody {
		whence = "Body.Read"
		if err != nil {
			t.Fatalf("RoundTrip error = %v, want success", err)
		}
		_, err = res.Body.Read(make([]byte, 1))
	}

	want := GoAwayError{
		LastStreamID: 5,
		ErrCode:      goAwayErrCode,
		DebugData:    goAwayDebugData,
	}
	if !reflect.DeepEqual(err, want) {
		t.Errorf("%v error = %T: %#v, want %T (%#v)", whence, err, err, want, want)
	}
}
2505
// testTransportReturnsUnusedFlowControl checks what happens when a response
// body announced as 5000 bytes is closed after reading a single byte. The
// Transport must reset the stream with RST_STREAM(CANCEL) and send a
// WINDOW_UPDATE of 5000. The connection-level inflow window must end up
// equal to its value before any body data arrived.
//
// If oneDataFrame is true, the whole 5000-byte body is sent before the client
// closes it. Otherwise only 1 byte is sent first. The remaining 4999 bytes
// are sent only after the client's RST_STREAM is seen, and a WINDOW_UPDATE
// arriving before that point fails the test.
func testTransportReturnsUnusedFlowControl(t *testing.T, oneDataFrame bool) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  false,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
			"content-length", "5000",
		),
	})
	initialInflow := tc.inflowWindow(0)

	// Send either the entire 5000-byte body or just its first byte.
	// END_STREAM is never set, so the stream stays open from the server's side.
	const streamNotEnded = false
	if oneDataFrame {
		tc.writeData(rt.streamID(), streamNotEnded, make([]byte, 5000))
	} else {
		tc.writeData(rt.streamID(), streamNotEnded, make([]byte, 1))
	}

	// Read one byte, then abandon the rest of the body.
	res := rt.response()
	if n, err := res.Body.Read(make([]byte, 1)); err != nil || n != 1 {
		t.Fatalf("body read = %v, %v; want 1, nil", n, err)
	}
	res.Body.Close()
	tc.sync()

	sentAdditionalData := false
	tc.wantUnorderedFrames(
		func(f *RSTStreamFrame) bool {
			if f.ErrCode != ErrCodeCancel {
				t.Fatalf("Expected a RSTStreamFrame with code cancel; got %v", summarizeFrame(f))
			}
			if !oneDataFrame {
				// Send the rest of the body after the reset. Its
				// flow-control tokens must still be returned to the connection.
				tc.writeData(rt.streamID(), streamNotEnded, make([]byte, 4999))
				sentAdditionalData = true
			}
			return true
		},
		func(f *PingFrame) bool {
			// A PING frame is accepted without further checks.
			return true
		},
		func(f *WindowUpdateFrame) bool {
			if !oneDataFrame && !sentAdditionalData {
				t.Fatalf("Got WindowUpdateFrame, don't expect one yet")
			}
			if f.Increment != 5000 {
				t.Fatalf("Expected WindowUpdateFrames for 5000 bytes; got %v", summarizeFrame(f))
			}
			return true
		},
	)

	if got, want := tc.inflowWindow(0), initialInflow; got != want {
		t.Fatalf("connection flow tokens = %v, want %v", got, want)
	}
}
2580
2581
2582 func TestTransportReturnsUnusedFlowControlSingleWrite(t *testing.T) {
2583 testTransportReturnsUnusedFlowControl(t, true)
2584 }
2585
2586
2587 func TestTransportReturnsUnusedFlowControlMultipleWrites(t *testing.T) {
2588 testTransportReturnsUnusedFlowControl(t, false)
2589 }
2590
2591
2592
// TestTransportAdjustsFlowControl sends a 1 MiB request body. Once part of
// the body has been received, the server raises SETTINGS_INITIAL_WINDOW_SIZE
// and the connection window to the body size. The test checks that the
// client then sends the rest of the body, ending with END_STREAM, and that
// exactly bodySize bytes arrive in total.
func TestTransportAdjustsFlowControl(t *testing.T) {
	const bodySize = 1 << 20

	tc := newTestClientConn(t)
	tc.wantFrameType(FrameSettings)
	tc.wantFrameType(FrameWindowUpdate)
	// greet() is deliberately not called: the server's SETTINGS are only
	// sent after part of the body has been received.

	body := tc.newRequestBody()
	body.writeBytes(bodySize)
	body.closeWithError(io.EOF)

	req, _ := http.NewRequest("POST", "https://dummy.tld/", body)
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)

	gotBytes := int64(0)
	for {
		f := readFrame[*DataFrame](t, tc)
		gotBytes += int64(len(f.Data()))

		// Stop once at least half of initialWindowSize has been received.
		if gotBytes >= initialWindowSize/2 {
			break
		}
	}

	// Raise the stream window to the full body size and grow the
	// connection window by the same amount.
	tc.writeSettings(Setting{ID: SettingInitialWindowSize, Val: bodySize})
	tc.writeWindowUpdate(0, bodySize)
	tc.writeSettingsAck()

	// Expect a SETTINGS frame (presumably the ACK) and the remaining DATA,
	// ending with END_STREAM.
	tc.wantUnorderedFrames(
		func(f *SettingsFrame) bool { return true },
		func(f *DataFrame) bool {
			gotBytes += int64(len(f.Data()))
			return f.StreamEnded()
		},
	)

	if gotBytes != bodySize {
		t.Fatalf("server received %v bytes of body, want %v", gotBytes, bodySize)
	}

	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
		),
	})
	rt.wantStatus(200)
}
2647
2648
// TestTransportReturnsDataPaddingFlowControl sends a DATA frame carrying a
// 5000-byte payload plus 5 bytes of padding. It checks that both the
// connection and the stream inflow windows shrink by exactly the payload size
// (5000). In other words, the flow-control credit used by padding has already
// been given back.
func TestTransportReturnsDataPaddingFlowControl(t *testing.T) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  false,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
			"content-length", "5000",
		),
	})

	initialConnWindow := tc.inflowWindow(0)
	initialStreamWindow := tc.inflowWindow(rt.streamID())

	pad := make([]byte, 5)
	tc.writeDataPadded(rt.streamID(), false, make([]byte, 5000), pad)

	// Only the unread payload should still be counted against the windows.
	// The padding is expected to have been credited back already.
	if got, want := tc.inflowWindow(0), initialConnWindow-5000; got != want {
		t.Errorf("conn inflow window = %v, want %v", got, want)
	}
	if got, want := tc.inflowWindow(rt.streamID()), initialStreamWindow-5000; got != want {
		t.Errorf("stream inflow window = %v, want %v", got, want)
	}
}
2681
2682
2683
// TestTransportReturnsErrorOnBadResponseHeaders sends a response header whose
// field name is invalid (it starts with a space). It checks two things:
//   - RoundTrip fails with StreamError{1, ErrCodeProtocol, headerFieldNameError}.
//   - The Transport resets stream 1 with RST_STREAM(PROTOCOL_ERROR).
func TestTransportReturnsErrorOnBadResponseHeaders(t *testing.T) {
	tc := newTestClientConn(t)
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  false,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
			" content-type", "bogus", // leading space makes the name invalid
		),
	})

	err := rt.err()
	want := StreamError{1, ErrCodeProtocol, headerFieldNameError(" content-type")}
	if !reflect.DeepEqual(err, want) {
		t.Fatalf("RoundTrip error = %#v; want %#v", err, want)
	}

	fr := readFrame[*RSTStreamFrame](t, tc)
	if fr.StreamID != 1 || fr.ErrCode != ErrCodeProtocol {
		t.Errorf("Frame = %v; want RST_STREAM for stream 1 with ErrCodeProtocol", summarizeFrame(fr))
	}
}
2713
2714
2715
// byteAndEOFReader is an io.Reader whose every Read stores the reader's own
// byte value into p[0] and returns (1, io.EOF). In other words, the final
// data and the end-of-stream signal arrive together. Read panics if p is
// empty.
type byteAndEOFReader byte

func (b byteAndEOFReader) Read(p []byte) (n int, err error) {
	if len(p) > 0 {
		p[0] = byte(b)
		return 1, io.EOF
	}
	panic("unexpected useless call")
}
2725
2726
2727
2728
2729
2730
2731
2732
2733
2734
2735 func TestTransportBodyDoubleEndStream(t *testing.T) {
2736 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
2737
2738 })
2739
2740 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
2741 defer tr.CloseIdleConnections()
2742
2743 for i := 0; i < 2; i++ {
2744 req, _ := http.NewRequest("POST", ts.URL, byteAndEOFReader('a'))
2745 req.ContentLength = 1
2746 res, err := tr.RoundTrip(req)
2747 if err != nil {
2748 t.Fatalf("failure on req %d: %v", i+1, err)
2749 }
2750 defer res.Body.Close()
2751 }
2752 }
2753
2754
2755 func TestTransportRequestPathPseudo(t *testing.T) {
2756 type result struct {
2757 path string
2758 err string
2759 }
2760 tests := []struct {
2761 req *http.Request
2762 want result
2763 }{
2764 0: {
2765 req: &http.Request{
2766 Method: "GET",
2767 URL: &url.URL{
2768 Host: "foo.com",
2769 Path: "/foo",
2770 },
2771 },
2772 want: result{path: "/foo"},
2773 },
2774
2775
2776
2777 1: {
2778 req: &http.Request{
2779 Method: "GET",
2780 URL: &url.URL{
2781 Host: "foo.com",
2782 Path: "//foo",
2783 },
2784 },
2785 want: result{path: "//foo"},
2786 },
2787
2788
2789 2: {
2790 req: &http.Request{
2791 Method: "GET",
2792 URL: &url.URL{
2793 Scheme: "https",
2794 Opaque: "//foo.com/path",
2795 Host: "foo.com",
2796 Path: "/ignored",
2797 },
2798 },
2799 want: result{path: "/path"},
2800 },
2801
2802
2803 3: {
2804 req: &http.Request{
2805 Method: "GET",
2806 Host: "bar.com",
2807 URL: &url.URL{
2808 Scheme: "https",
2809 Opaque: "//bar.com/path",
2810 Host: "foo.com",
2811 Path: "/ignored",
2812 },
2813 },
2814 want: result{path: "/path"},
2815 },
2816
2817
2818 4: {
2819 req: &http.Request{
2820 Method: "GET",
2821 URL: &url.URL{
2822 Opaque: "/path",
2823 Host: "foo.com",
2824 Path: "/ignored",
2825 },
2826 },
2827 want: result{path: "/path"},
2828 },
2829
2830
2831 5: {
2832 req: &http.Request{
2833 Method: "GET",
2834 URL: &url.URL{
2835 Scheme: "https",
2836 Opaque: "//unknown_host/path",
2837 Host: "foo.com",
2838 Path: "/ignored",
2839 },
2840 },
2841 want: result{err: `invalid request :path "https://unknown_host/path" from URL.Opaque = "//unknown_host/path"`},
2842 },
2843
2844
2845 6: {
2846 req: &http.Request{
2847 Method: "CONNECT",
2848 URL: &url.URL{
2849 Host: "foo.com",
2850 },
2851 },
2852 want: result{},
2853 },
2854 }
2855 for i, tt := range tests {
2856 cc := &ClientConn{peerMaxHeaderListSize: 0xffffffffffffffff}
2857 cc.henc = hpack.NewEncoder(&cc.hbuf)
2858 cc.mu.Lock()
2859 hdrs, err := cc.encodeHeaders(tt.req, false, "", -1)
2860 cc.mu.Unlock()
2861 var got result
2862 hpackDec := hpack.NewDecoder(initialHeaderTableSize, func(f hpack.HeaderField) {
2863 if f.Name == ":path" {
2864 got.path = f.Value
2865 }
2866 })
2867 if err != nil {
2868 got.err = err.Error()
2869 } else if len(hdrs) > 0 {
2870 if _, err := hpackDec.Write(hdrs); err != nil {
2871 t.Errorf("%d. bogus hpack: %v", i, err)
2872 continue
2873 }
2874 }
2875 if got != tt.want {
2876 t.Errorf("%d. got %+v; want %+v", i, got, tt.want)
2877 }
2878
2879 }
2880
2881 }
2882
2883
2884
2885 func TestRoundTripDoesntConsumeRequestBodyEarly(t *testing.T) {
2886 const body = "foo"
2887 req, _ := http.NewRequest("POST", "http://foo.com/", io.NopCloser(strings.NewReader(body)))
2888 cc := &ClientConn{
2889 closed: true,
2890 reqHeaderMu: make(chan struct{}, 1),
2891 t: &Transport{},
2892 }
2893 _, err := cc.RoundTrip(req)
2894 if err != errClientConnUnusable {
2895 t.Fatalf("RoundTrip = %v; want errClientConnUnusable", err)
2896 }
2897 slurp, err := io.ReadAll(req.Body)
2898 if err != nil {
2899 t.Errorf("ReadAll = %v", err)
2900 }
2901 if string(slurp) != body {
2902 t.Errorf("Body = %q; want %q", slurp, body)
2903 }
2904 }
2905
2906 func TestClientConnPing(t *testing.T) {
2907 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {})
2908 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
2909 defer tr.CloseIdleConnections()
2910 ctx := context.Background()
2911 cc, err := tr.dialClientConn(ctx, ts.Listener.Addr().String(), false)
2912 if err != nil {
2913 t.Fatal(err)
2914 }
2915 if err = cc.Ping(context.Background()); err != nil {
2916 t.Fatal(err)
2917 }
2918 }
2919
2920
2921
2922
2923
2924 func TestTransportCancelDataResponseRace(t *testing.T) {
2925 cancel := make(chan struct{})
2926 clientGotResponse := make(chan bool, 1)
2927
2928 const msg = "Hello."
2929 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
2930 if strings.Contains(r.URL.Path, "/hello") {
2931 time.Sleep(50 * time.Millisecond)
2932 io.WriteString(w, msg)
2933 return
2934 }
2935 for i := 0; i < 50; i++ {
2936 io.WriteString(w, "Some data.")
2937 w.(http.Flusher).Flush()
2938 if i == 2 {
2939 <-clientGotResponse
2940 close(cancel)
2941 }
2942 time.Sleep(10 * time.Millisecond)
2943 }
2944 })
2945
2946 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
2947 defer tr.CloseIdleConnections()
2948
2949 c := &http.Client{Transport: tr}
2950 req, _ := http.NewRequest("GET", ts.URL, nil)
2951 req.Cancel = cancel
2952 res, err := c.Do(req)
2953 clientGotResponse <- true
2954 if err != nil {
2955 t.Fatal(err)
2956 }
2957 if _, err = io.Copy(io.Discard, res.Body); err == nil {
2958 t.Fatal("unexpected success")
2959 }
2960
2961 res, err = c.Get(ts.URL + "/hello")
2962 if err != nil {
2963 t.Fatal(err)
2964 }
2965 slurp, err := io.ReadAll(res.Body)
2966 if err != nil {
2967 t.Fatal(err)
2968 }
2969 if string(slurp) != msg {
2970 t.Errorf("Got = %q; want %q", slurp, msg)
2971 }
2972 }
2973
2974
2975
2976 func TestTransportNoRaceOnRequestObjectAfterRequestComplete(t *testing.T) {
2977 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
2978 w.WriteHeader(200)
2979 io.WriteString(w, "body")
2980 })
2981
2982 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
2983 defer tr.CloseIdleConnections()
2984
2985 req, _ := http.NewRequest("GET", ts.URL, nil)
2986 resp, err := tr.RoundTrip(req)
2987 if err != nil {
2988 t.Fatal(err)
2989 }
2990 if _, err = io.Copy(io.Discard, resp.Body); err != nil {
2991 t.Fatalf("error reading response body: %v", err)
2992 }
2993 if err := resp.Body.Close(); err != nil {
2994 t.Fatalf("error closing response body: %v", err)
2995 }
2996
2997
2998 req.Header = http.Header{}
2999 }
3000
// TestTransportCloseAfterLostPing checks the lost-PING path. After
// ReadIdleTimeout with no incoming frames, the Transport must send a PING. If
// no reply arrives within PingTimeout, the in-flight RoundTrip must fail with
// an error mentioning "client connection lost".
func TestTransportCloseAfterLostPing(t *testing.T) {
	tc := newTestClientConn(t, func(tr *Transport) {
		tr.PingTimeout = 1 * time.Second
		tr.ReadIdleTimeout = 1 * time.Second
	})
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)
	tc.wantFrameType(FrameHeaders)

	// ReadIdleTimeout elapses: a PING is expected.
	tc.advance(1 * time.Second)
	tc.wantFrameType(FramePing)

	// PingTimeout elapses without an ACK: the request should fail.
	tc.advance(1 * time.Second)
	err := rt.err()
	if err == nil || !strings.Contains(err.Error(), "client connection lost") {
		t.Fatalf("expected to get error about \"connection lost\", got %v", err)
	}
}
3021
3022 func TestTransportPingWriteBlocks(t *testing.T) {
3023 ts := newTestServer(t,
3024 func(w http.ResponseWriter, r *http.Request) {},
3025 )
3026 tr := &Transport{
3027 TLSClientConfig: tlsConfigInsecure,
3028 DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
3029 s, c := net.Pipe()
3030 go func() {
3031
3032
3033
3034 var buf [1024]byte
3035 s.Read(buf[:])
3036 }()
3037 return c, nil
3038 },
3039 PingTimeout: 1 * time.Millisecond,
3040 ReadIdleTimeout: 1 * time.Millisecond,
3041 }
3042 defer tr.CloseIdleConnections()
3043 c := &http.Client{Transport: tr}
3044 _, err := c.Get(ts.URL)
3045 if err == nil {
3046 t.Fatalf("Get = nil, want error")
3047 }
3048 }
3049
// TestTransportPingWhenReadingMultiplePings opens a response body that
// receives no data. It checks three things:
//   - The Transport sends a PING each time ReadIdleTimeout elapses with no
//     incoming frames, and sends nothing before that.
//   - Acknowledging each PING restarts the idle period.
//   - Canceling the request context afterwards resets the stream, and the
//     body read fails.
func TestTransportPingWhenReadingMultiplePings(t *testing.T) {
	tc := newTestClientConn(t, func(tr *Transport) {
		tr.ReadIdleTimeout = 1000 * time.Millisecond
	})
	tc.greet()

	ctx, cancel := context.WithCancel(context.Background())
	req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  false,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
		),
	})

	for i := 0; i < 5; i++ {
		// Just short of ReadIdleTimeout: nothing should be sent yet.
		tc.advance(999 * time.Millisecond)
		if f := tc.readFrame(); f != nil {
			t.Fatalf("unexpected frame: %v", f)
		}

		// ReadIdleTimeout reached: expect a PING, and acknowledge it.
		tc.advance(1 * time.Millisecond)
		f := readFrame[*PingFrame](t, tc)
		tc.writePing(true, f.Data)
	}

	// Cancel the request: the stream should be reset and the body read should fail.
	cancel()
	tc.sync()

	tc.wantFrameType(FrameRSTStream)
	_, err := rt.readBody()
	if err == nil {
		t.Fatalf("Response.Body.Read() = %v, want error", err)
	}
}
3093
// TestTransportPingWhenReadingPingDisabled checks that with ReadIdleTimeout
// set to 0, the Transport sends no frame (in particular no PING) even after a
// minute passes with an open response body and no incoming data.
func TestTransportPingWhenReadingPingDisabled(t *testing.T) {
	tc := newTestClientConn(t, func(tr *Transport) {
		tr.ReadIdleTimeout = 0 // PINGs disabled
	})
	tc.greet()

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tc.roundTrip(req)

	tc.wantFrameType(FrameHeaders)
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rt.streamID(),
		EndHeaders: true,
		EndStream:  false,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
		),
	})

	// Let a long idle period pass; nothing should be written.
	tc.advance(1 * time.Minute)
	if f := tc.readFrame(); f != nil {
		t.Fatalf("unexpected frame: %v", f)
	}
}
3119
// TestTransportRetryAfterGOAWAYNoRetry sends GOAWAY(last stream ID 0,
// ErrCodeInternal) after the request's HEADERS. It checks that RoundTrip
// finishes with an error instead of being retried.
func TestTransportRetryAfterGOAWAYNoRetry(t *testing.T) {
	tt := newTestTransport(t)

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tt.roundTrip(req)

	// The Transport dials a new connection. Expect the client's initial
	// SETTINGS and WINDOW_UPDATE, then the request's HEADERS on stream 1.
	tc := tt.getConn()
	tc.wantFrameType(FrameSettings)
	tc.wantFrameType(FrameWindowUpdate)
	tc.wantHeaders(wantHeader{
		streamID:  1,
		endStream: true,
	})
	tc.writeSettings()
	tc.writeGoAway(0 /* lastStreamID */, ErrCodeInternal, nil)
	if rt.err() == nil {
		t.Fatalf("after GOAWAY, RoundTrip is not done, want error")
	}
}
3143
// TestTransportRetryAfterGOAWAYRetry sends GOAWAY(last stream ID 0, ErrCodeNo)
// after the request's HEADERS on stream 1. It checks that RoundTrip is not
// finished by the GOAWAY. Instead, the request is sent again on a new
// connection (again as stream 1) and completes with that connection's
// response.
func TestTransportRetryAfterGOAWAYRetry(t *testing.T) {
	tt := newTestTransport(t)

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tt.roundTrip(req)

	// First connection: client preface frames, then the request's HEADERS.
	tc := tt.getConn()
	tc.wantFrameType(FrameSettings)
	tc.wantFrameType(FrameWindowUpdate)
	tc.wantHeaders(wantHeader{
		streamID:  1,
		endStream: true,
	})
	tc.writeSettings()
	tc.writeGoAway(0 /* lastStreamID */, ErrCodeNo, nil)
	if rt.done() {
		t.Fatalf("after GOAWAY, RoundTrip is done; want it to be retrying")
	}

	// Second connection: the retried request arrives and is answered.
	tc = tt.getConn()
	tc.wantFrameType(FrameSettings)
	tc.wantFrameType(FrameWindowUpdate)
	tc.wantHeaders(wantHeader{
		streamID:  1,
		endStream: true,
	})
	tc.writeSettings()
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   1,
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
		),
	})

	rt.wantStatus(200)
}
3187
// TestTransportRetryAfterGOAWAYSecondRequest covers a GOAWAY that arrives
// while a second request is in flight. The first request completes normally
// on stream 1. A second request is then sent on stream 3, and the server
// replies with GOAWAY(last stream ID 1, ErrCodeProtocol). Because the
// GOAWAY's last stream ID is below the second request's stream, the test
// expects that request to be retried on a new connection (as stream 1) and
// to complete there.
func TestTransportRetryAfterGOAWAYSecondRequest(t *testing.T) {
	tt := newTestTransport(t)

	// First request: succeeds on stream 1.
	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt1 := tt.roundTrip(req)
	tc := tt.getConn()
	tc.wantFrameType(FrameSettings)
	tc.wantFrameType(FrameWindowUpdate)
	tc.wantHeaders(wantHeader{
		streamID:  1,
		endStream: true,
	})
	tc.writeSettings()
	tc.wantFrameType(FrameSettings)
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   1,
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
		),
	})
	rt1.wantStatus(200)

	// Second request on the same connection, stream 3.
	req, _ = http.NewRequest("GET", "https://dummy.tld/", nil)
	rt2 := tt.roundTrip(req)

	tc.wantHeaders(wantHeader{
		streamID:  3,
		endStream: true,
	})
	tc.writeSettings()
	tc.writeGoAway(1 /* lastStreamID */, ErrCodeProtocol, nil)
	if rt2.done() {
		t.Fatalf("after GOAWAY, RoundTrip is done; want it to be retrying")
	}

	// The retry arrives on a fresh connection as stream 1.
	tc = tt.getConn()
	tc.wantFrameType(FrameSettings)
	tc.wantFrameType(FrameWindowUpdate)
	tc.wantHeaders(wantHeader{
		streamID:  1,
		endStream: true,
	})
	tc.writeSettings()
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   1,
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
		),
	})
	rt2.wantStatus(200)
}
3250
// TestTransportRetryAfterRefusedStream checks that a request reset with
// RST_STREAM(REFUSED_STREAM) is not finished by the reset. Instead, it is
// retried on the same connection as stream 3 and completes with the response
// sent there.
func TestTransportRetryAfterRefusedStream(t *testing.T) {
	tt := newTestTransport(t)

	req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
	rt := tt.roundTrip(req)

	// First attempt on stream 1 is refused.
	tc := tt.getConn()
	tc.wantFrameType(FrameSettings)
	tc.wantFrameType(FrameWindowUpdate)
	tc.wantHeaders(wantHeader{
		streamID:  1,
		endStream: true,
	})
	tc.writeSettings()
	tc.wantFrameType(FrameSettings)
	tc.writeRSTStream(1, ErrCodeRefusedStream)
	if rt.done() {
		t.Fatalf("after RST_STREAM, RoundTrip is done; want it to be retrying")
	}

	// The retry uses the same connection, on stream 3.
	tc.wantHeaders(wantHeader{
		streamID:  3,
		endStream: true,
	})
	tc.writeSettings()
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   3,
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "204",
		),
	})

	rt.wantStatus(204)
}
3289
3290 func TestTransportRetryHasLimit(t *testing.T) {
3291 tt := newTestTransport(t)
3292
3293 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
3294 rt := tt.roundTrip(req)
3295
3296
3297 tc := tt.getConn()
3298 tc.wantFrameType(FrameSettings)
3299 tc.wantFrameType(FrameWindowUpdate)
3300
3301 var totalDelay time.Duration
3302 count := 0
3303 for streamID := uint32(1); ; streamID += 2 {
3304 count++
3305 tc.wantHeaders(wantHeader{
3306 streamID: streamID,
3307 endStream: true,
3308 })
3309 if streamID == 1 {
3310 tc.writeSettings()
3311 tc.wantFrameType(FrameSettings)
3312 }
3313 tc.writeRSTStream(streamID, ErrCodeRefusedStream)
3314
3315 d, scheduled := tt.group.TimeUntilEvent()
3316 if !scheduled {
3317 if streamID == 1 {
3318 continue
3319 }
3320 break
3321 }
3322 totalDelay += d
3323 if totalDelay > 5*time.Minute {
3324 t.Fatalf("RoundTrip still retrying after %v, should have given up", totalDelay)
3325 }
3326 tt.advance(d)
3327 }
3328 if got, want := count, 5; got < count {
3329 t.Errorf("RoundTrip made %v attempts, want at least %v", got, want)
3330 }
3331 if rt.err() == nil {
3332 t.Errorf("RoundTrip succeeded, want error")
3333 }
3334 }
3335
3336 func TestTransportResponseDataBeforeHeaders(t *testing.T) {
3337
3338 log.SetOutput(io.Discard)
3339 t.Cleanup(func() { log.SetOutput(os.Stderr) })
3340
3341 tc := newTestClientConn(t)
3342 tc.greet()
3343
3344
3345 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
3346 rt1 := tc.roundTrip(req)
3347 tc.wantFrameType(FrameHeaders)
3348 tc.writeHeaders(HeadersFrameParam{
3349 StreamID: rt1.streamID(),
3350 EndHeaders: true,
3351 EndStream: true,
3352 BlockFragment: tc.makeHeaderBlockFragment(
3353 ":status", "200",
3354 ),
3355 })
3356 rt1.wantStatus(200)
3357
3358
3359 rt2 := tc.roundTrip(req)
3360 tc.wantFrameType(FrameHeaders)
3361 tc.writeData(rt2.streamID(), true, []byte("payload"))
3362 if err, ok := rt2.err().(StreamError); !ok || err.Code != ErrCodeProtocol {
3363 t.Fatalf("expected stream PROTOCOL_ERROR, got: %v", err)
3364 }
3365 }
3366
3367 func TestTransportMaxFrameReadSize(t *testing.T) {
3368 for _, test := range []struct {
3369 maxReadFrameSize uint32
3370 want uint32
3371 }{{
3372 maxReadFrameSize: 64000,
3373 want: 64000,
3374 }, {
3375 maxReadFrameSize: 1024,
3376 want: minMaxFrameSize,
3377 }} {
3378 t.Run(fmt.Sprint(test.maxReadFrameSize), func(t *testing.T) {
3379 tc := newTestClientConn(t, func(tr *Transport) {
3380 tr.MaxReadFrameSize = test.maxReadFrameSize
3381 })
3382
3383 fr := readFrame[*SettingsFrame](t, tc)
3384 got, ok := fr.Value(SettingMaxFrameSize)
3385 if !ok {
3386 t.Errorf("Transport.MaxReadFrameSize = %v; server got no setting, want %v", test.maxReadFrameSize, test.want)
3387 } else if got != test.want {
3388 t.Errorf("Transport.MaxReadFrameSize = %v; server got %v, want %v", test.maxReadFrameSize, got, test.want)
3389 }
3390 })
3391 }
3392 }
3393
3394 func TestTransportRequestsLowServerLimit(t *testing.T) {
3395 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
3396 }, func(s *Server) {
3397 s.MaxConcurrentStreams = 1
3398 })
3399
3400 var (
3401 connCountMu sync.Mutex
3402 connCount int
3403 )
3404 tr := &Transport{
3405 TLSClientConfig: tlsConfigInsecure,
3406 DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
3407 connCountMu.Lock()
3408 defer connCountMu.Unlock()
3409 connCount++
3410 return tls.Dial(network, addr, cfg)
3411 },
3412 }
3413 defer tr.CloseIdleConnections()
3414
3415 const reqCount = 3
3416 for i := 0; i < reqCount; i++ {
3417 req, err := http.NewRequest("GET", ts.URL, nil)
3418 if err != nil {
3419 t.Fatal(err)
3420 }
3421 res, err := tr.RoundTrip(req)
3422 if err != nil {
3423 t.Fatal(err)
3424 }
3425 if got, want := res.StatusCode, 200; got != want {
3426 t.Errorf("StatusCode = %v; want %v", got, want)
3427 }
3428 if res != nil && res.Body != nil {
3429 res.Body.Close()
3430 }
3431 }
3432
3433 if connCount != 1 {
3434 t.Errorf("created %v connections for %v requests, want 1", connCount, reqCount)
3435 }
3436 }
3437
3438
// TestTransportRequestsStallAtServerLimit uses StrictMaxConcurrentStreams
// with a server limit of maxConcurrent streams. It checks that:
//   - requests beyond the limit are queued without sending HEADERS;
//   - a queued request can be canceled, failing only that request;
//   - finishing an active stream lets the next queued, non-canceled request
//     send its HEADERS.
func TestTransportRequestsStallAtServerLimit(t *testing.T) {
	const maxConcurrent = 2

	tc := newTestClientConn(t, func(tr *Transport) {
		tr.StrictMaxConcurrentStreams = true
	})
	tc.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent})

	cancelClientRequest := make(chan struct{})

	// Start maxConcurrent+2 requests. The first maxConcurrent go out
	// immediately and the rest must wait. The request at index maxConcurrent
	// is cancelable through cancelClientRequest.
	var rts []*testRoundTrip
	for k := 0; k < maxConcurrent+2; k++ {
		req, _ := http.NewRequest("GET", fmt.Sprintf("https://dummy.tld/%d", k), nil)
		if k == maxConcurrent {
			req.Cancel = cancelClientRequest
		}
		rt := tc.roundTrip(req)
		rts = append(rts, rt)

		if k < maxConcurrent {
			// Below the limit: HEADERS are sent right away.
			tc.wantHeaders(wantHeader{
				streamID:  rt.streamID(),
				endStream: true,
				header: http.Header{
					":authority": []string{"dummy.tld"},
					":method":    []string{"GET"},
					":path":      []string{fmt.Sprintf("/%d", k)},
				},
			})
		} else {
			// At the limit: the request must wait, so no frame may be
			// written yet.
			if fr := tc.readFrame(); fr != nil {
				t.Fatalf("after making new request while at stream limit, got unexpected frame: %v", fr)
			}
		}

		if rt.done() {
			t.Fatalf("rt %v done", k)
		}
	}

	// Cancel the first queued request; its RoundTrip should fail.
	close(cancelClientRequest)
	tc.sync()
	if err := rts[maxConcurrent].err(); err == nil {
		t.Fatalf("RoundTrip(%d) should have failed due to cancel, did not", maxConcurrent)
	}

	// Every other request is still pending.
	for i, rt := range rts {
		if i != maxConcurrent && rt.done() {
			t.Fatalf("RoundTrip(%d) is done, but should not be", i)
		}
	}

	// Finishing the first stream frees a slot. The remaining queued request
	// (the last one) should now send its HEADERS.
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   rts[0].streamID(),
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
		),
	})
	tc.wantHeaders(wantHeader{
		streamID:  rts[maxConcurrent+1].streamID(),
		endStream: true,
		header: http.Header{
			":authority": []string{"dummy.tld"},
			":method":    []string{"GET"},
			":path":      []string{fmt.Sprintf("/%d", maxConcurrent+1)},
		},
	})
	rts[0].wantStatus(200)
}
3519
3520 func TestTransportMaxDecoderHeaderTableSize(t *testing.T) {
3521 var reqSize, resSize uint32 = 8192, 16384
3522 tc := newTestClientConn(t, func(tr *Transport) {
3523 tr.MaxDecoderHeaderTableSize = reqSize
3524 })
3525
3526 fr := readFrame[*SettingsFrame](t, tc)
3527 if v, ok := fr.Value(SettingHeaderTableSize); !ok {
3528 t.Fatalf("missing SETTINGS_HEADER_TABLE_SIZE setting")
3529 } else if v != reqSize {
3530 t.Fatalf("received SETTINGS_HEADER_TABLE_SIZE = %d, want %d", v, reqSize)
3531 }
3532
3533 tc.writeSettings(Setting{SettingHeaderTableSize, resSize})
3534 tc.cc.mu.Lock()
3535 defer tc.cc.mu.Unlock()
3536 if got, want := tc.cc.peerMaxHeaderTableSize, resSize; got != want {
3537 t.Fatalf("peerHeaderTableSize = %d, want %d", got, want)
3538 }
3539 }
3540
3541 func TestTransportMaxEncoderHeaderTableSize(t *testing.T) {
3542 var peerAdvertisedMaxHeaderTableSize uint32 = 16384
3543 tc := newTestClientConn(t, func(tr *Transport) {
3544 tr.MaxEncoderHeaderTableSize = 8192
3545 })
3546 tc.greet(Setting{SettingHeaderTableSize, peerAdvertisedMaxHeaderTableSize})
3547
3548 if got, want := tc.cc.henc.MaxDynamicTableSize(), tc.tr.MaxEncoderHeaderTableSize; got != want {
3549 t.Fatalf("henc.MaxDynamicTableSize() = %d, want %d", got, want)
3550 }
3551 }
3552
3553 func TestAuthorityAddr(t *testing.T) {
3554 tests := []struct {
3555 scheme, authority string
3556 want string
3557 }{
3558 {"http", "foo.com", "foo.com:80"},
3559 {"https", "foo.com", "foo.com:443"},
3560 {"https", "foo.com:", "foo.com:443"},
3561 {"https", "foo.com:1234", "foo.com:1234"},
3562 {"https", "1.2.3.4:1234", "1.2.3.4:1234"},
3563 {"https", "1.2.3.4", "1.2.3.4:443"},
3564 {"https", "1.2.3.4:", "1.2.3.4:443"},
3565 {"https", "[::1]:1234", "[::1]:1234"},
3566 {"https", "[::1]", "[::1]:443"},
3567 {"https", "[::1]:", "[::1]:443"},
3568 }
3569 for _, tt := range tests {
3570 got := authorityAddr(tt.scheme, tt.authority)
3571 if got != tt.want {
3572 t.Errorf("authorityAddr(%q, %q) = %q; want %q", tt.scheme, tt.authority, got, tt.want)
3573 }
3574 }
3575 }
3576
3577
3578
3579 func TestTransportAllocationsAfterResponseBodyClose(t *testing.T) {
3580 megabyteZero := make([]byte, 1<<20)
3581
3582 writeErr := make(chan error, 1)
3583
3584 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
3585 w.(http.Flusher).Flush()
3586 var sum int64
3587 for i := 0; i < 100; i++ {
3588 n, err := w.Write(megabyteZero)
3589 sum += int64(n)
3590 if err != nil {
3591 writeErr <- err
3592 return
3593 }
3594 }
3595 t.Logf("wrote all %d bytes", sum)
3596 writeErr <- nil
3597 })
3598
3599 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
3600 defer tr.CloseIdleConnections()
3601 c := &http.Client{Transport: tr}
3602 res, err := c.Get(ts.URL)
3603 if err != nil {
3604 t.Fatal(err)
3605 }
3606 var buf [1]byte
3607 if _, err := res.Body.Read(buf[:]); err != nil {
3608 t.Error(err)
3609 }
3610 if err := res.Body.Close(); err != nil {
3611 t.Error(err)
3612 }
3613
3614 trb, ok := res.Body.(transportResponseBody)
3615 if !ok {
3616 t.Fatalf("res.Body = %T; want transportResponseBody", res.Body)
3617 }
3618 if trb.cs.bufPipe.b != nil {
3619 t.Errorf("response body pipe is still open")
3620 }
3621
3622 gotErr := <-writeErr
3623 if gotErr == nil {
3624 t.Errorf("Handler unexpectedly managed to write its entire response without getting an error")
3625 } else if gotErr != errStreamClosed {
3626 t.Errorf("Handler Write err = %v; want errStreamClosed", gotErr)
3627 }
3628 }
3629
3630
3631
3632 func TestTransportNoBodyMeansNoDATA(t *testing.T) {
3633 tc := newTestClientConn(t)
3634 tc.greet()
3635
3636 req, _ := http.NewRequest("GET", "https://dummy.tld/", http.NoBody)
3637 rt := tc.roundTrip(req)
3638
3639 tc.wantHeaders(wantHeader{
3640 streamID: rt.streamID(),
3641 endStream: true,
3642 header: http.Header{
3643 ":authority": []string{"dummy.tld"},
3644 ":method": []string{"GET"},
3645 ":path": []string{"/"},
3646 },
3647 })
3648 if fr := tc.readFrame(); fr != nil {
3649 t.Fatalf("unexpected frame after headers: %v", fr)
3650 }
3651 }
3652
3653 func benchSimpleRoundTrip(b *testing.B, nReqHeaders, nResHeader int) {
3654 disableGoroutineTracking(b)
3655 b.ReportAllocs()
3656 ts := newTestServer(b,
3657 func(w http.ResponseWriter, r *http.Request) {
3658 for i := 0; i < nResHeader; i++ {
3659 name := fmt.Sprint("A-", i)
3660 w.Header().Set(name, "*")
3661 }
3662 },
3663 optQuiet,
3664 )
3665
3666 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
3667 defer tr.CloseIdleConnections()
3668
3669 req, err := http.NewRequest("GET", ts.URL, nil)
3670 if err != nil {
3671 b.Fatal(err)
3672 }
3673
3674 for i := 0; i < nReqHeaders; i++ {
3675 name := fmt.Sprint("A-", i)
3676 req.Header.Set(name, "*")
3677 }
3678
3679 b.ResetTimer()
3680
3681 for i := 0; i < b.N; i++ {
3682 res, err := tr.RoundTrip(req)
3683 if err != nil {
3684 if res != nil {
3685 res.Body.Close()
3686 }
3687 b.Fatalf("RoundTrip err = %v; want nil", err)
3688 }
3689 res.Body.Close()
3690 if res.StatusCode != http.StatusOK {
3691 b.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK)
3692 }
3693 }
3694 }
3695
// infiniteReader is an io.Reader that never ends. Every Read reports that the
// whole of p was filled and returns a nil error, but it leaves p's contents
// untouched.
type infiniteReader struct{}

func (infiniteReader) Read(p []byte) (n int, err error) {
	n = len(p)
	return n, nil
}
3701
3702
3703
3704 func TestTransportResponseAndResetWithoutConsumingBodyRace(t *testing.T) {
3705 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
3706 w.WriteHeader(http.StatusOK)
3707 })
3708
3709 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
3710 defer tr.CloseIdleConnections()
3711
3712
3713 req, _ := http.NewRequest("PUT", ts.URL, infiniteReader{})
3714 res, err := tr.RoundTrip(req)
3715 if err != nil {
3716 t.Fatal(err)
3717 }
3718 if res.StatusCode != http.StatusOK {
3719 t.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK)
3720 }
3721 }
3722
3723
3724
3725 func TestTransportHandlesInvalidStatuslessResponse(t *testing.T) {
3726 tc := newTestClientConn(t)
3727 tc.greet()
3728
3729 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
3730 rt := tc.roundTrip(req)
3731
3732 tc.wantFrameType(FrameHeaders)
3733 tc.writeHeaders(HeadersFrameParam{
3734 StreamID: rt.streamID(),
3735 EndHeaders: true,
3736 EndStream: false,
3737 BlockFragment: tc.makeHeaderBlockFragment(
3738 "content-type", "text/html",
3739 ),
3740 })
3741 tc.writeData(rt.streamID(), true, []byte("payload"))
3742 }
3743
3744 func BenchmarkClientRequestHeaders(b *testing.B) {
3745 b.Run(" 0 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 0) })
3746 b.Run(" 10 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 10, 0) })
3747 b.Run(" 100 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 100, 0) })
3748 b.Run("1000 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 1000, 0) })
3749 }
3750
3751 func BenchmarkClientResponseHeaders(b *testing.B) {
3752 b.Run(" 0 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 0) })
3753 b.Run(" 10 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 10) })
3754 b.Run(" 100 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 100) })
3755 b.Run("1000 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 0, 1000) })
3756 }
3757
3758 func BenchmarkDownloadFrameSize(b *testing.B) {
3759 b.Run(" 16k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 16*1024) })
3760 b.Run(" 64k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 64*1024) })
3761 b.Run("128k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 128*1024) })
3762 b.Run("256k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 256*1024) })
3763 b.Run("512k Frame", func(b *testing.B) { benchLargeDownloadRoundTrip(b, 512*1024) })
3764 }
3765 func benchLargeDownloadRoundTrip(b *testing.B, frameSize uint32) {
3766 disableGoroutineTracking(b)
3767 const transferSize = 1024 * 1024 * 1024
3768 b.ReportAllocs()
3769 ts := newTestServer(b,
3770 func(w http.ResponseWriter, r *http.Request) {
3771
3772 w.Header().Set("Content-Length", strconv.Itoa(transferSize))
3773 w.Header().Set("Content-Transfer-Encoding", "binary")
3774 var data [1024 * 1024]byte
3775 for i := 0; i < transferSize/(1024*1024); i++ {
3776 w.Write(data[:])
3777 }
3778 }, optQuiet,
3779 )
3780
3781 tr := &Transport{TLSClientConfig: tlsConfigInsecure, MaxReadFrameSize: frameSize}
3782 defer tr.CloseIdleConnections()
3783
3784 req, err := http.NewRequest("GET", ts.URL, nil)
3785 if err != nil {
3786 b.Fatal(err)
3787 }
3788
3789 b.N = 3
3790 b.SetBytes(transferSize)
3791 b.ResetTimer()
3792
3793 for i := 0; i < b.N; i++ {
3794 res, err := tr.RoundTrip(req)
3795 if err != nil {
3796 if res != nil {
3797 res.Body.Close()
3798 }
3799 b.Fatalf("RoundTrip err = %v; want nil", err)
3800 }
3801 data, _ := io.ReadAll(res.Body)
3802 if len(data) != transferSize {
3803 b.Fatalf("Response length invalid")
3804 }
3805 res.Body.Close()
3806 if res.StatusCode != http.StatusOK {
3807 b.Fatalf("Response code = %v; want %v", res.StatusCode, http.StatusOK)
3808 }
3809 }
3810 }
3811
3812 func activeStreams(cc *ClientConn) int {
3813 count := 0
3814 cc.mu.Lock()
3815 defer cc.mu.Unlock()
3816 for _, cs := range cc.streams {
3817 select {
3818 case <-cs.abort:
3819 default:
3820 count++
3821 }
3822 }
3823 return count
3824 }
3825
// closeMode selects how testClientConnClose tears down its ClientConn while a
// request is in flight.
type closeMode int

const (
	closeAtHeaders closeMode = iota // ClientConn.Close before the handler writes headers
	closeAtBody                     // ClientConn.Close after part of the body is written
	shutdown                        // ClientConn.Shutdown with context.Background()
	shutdownCancel                  // ClientConn.Shutdown with an already-canceled context
)
3834
3835
3836 func testClientConnClose(t *testing.T, closeMode closeMode) {
3837 clientDone := make(chan struct{})
3838 defer close(clientDone)
3839 handlerDone := make(chan struct{})
3840 closeDone := make(chan struct{})
3841 beforeHeader := func() {}
3842 bodyWrite := func(w http.ResponseWriter) {}
3843 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
3844 defer close(handlerDone)
3845 beforeHeader()
3846 w.WriteHeader(http.StatusOK)
3847 w.(http.Flusher).Flush()
3848 bodyWrite(w)
3849 select {
3850 case <-w.(http.CloseNotifier).CloseNotify():
3851
3852 if closeMode == shutdown || closeMode == shutdownCancel {
3853 t.Error("expected request to complete")
3854 }
3855 case <-clientDone:
3856 if closeMode == closeAtHeaders || closeMode == closeAtBody {
3857 t.Error("expected connection closed by client")
3858 }
3859 }
3860 })
3861 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
3862 defer tr.CloseIdleConnections()
3863 ctx := context.Background()
3864 cc, err := tr.dialClientConn(ctx, ts.Listener.Addr().String(), false)
3865 req, err := http.NewRequest("GET", ts.URL, nil)
3866 if err != nil {
3867 t.Fatal(err)
3868 }
3869 if closeMode == closeAtHeaders {
3870 beforeHeader = func() {
3871 if err := cc.Close(); err != nil {
3872 t.Error(err)
3873 }
3874 close(closeDone)
3875 }
3876 }
3877 var sendBody chan struct{}
3878 if closeMode == closeAtBody {
3879 sendBody = make(chan struct{})
3880 bodyWrite = func(w http.ResponseWriter) {
3881 <-sendBody
3882 b := make([]byte, 32)
3883 w.Write(b)
3884 w.(http.Flusher).Flush()
3885 if err := cc.Close(); err != nil {
3886 t.Errorf("unexpected ClientConn close error: %v", err)
3887 }
3888 close(closeDone)
3889 w.Write(b)
3890 w.(http.Flusher).Flush()
3891 }
3892 }
3893 res, err := cc.RoundTrip(req)
3894 if res != nil {
3895 defer res.Body.Close()
3896 }
3897 if closeMode == closeAtHeaders {
3898 got := fmt.Sprint(err)
3899 want := "http2: client connection force closed via ClientConn.Close"
3900 if got != want {
3901 t.Fatalf("RoundTrip error = %v, want %v", got, want)
3902 }
3903 } else {
3904 if err != nil {
3905 t.Fatalf("RoundTrip: %v", err)
3906 }
3907 if got, want := activeStreams(cc), 1; got != want {
3908 t.Errorf("got %d active streams, want %d", got, want)
3909 }
3910 }
3911 switch closeMode {
3912 case shutdownCancel:
3913 if err = cc.Shutdown(canceledCtx); err != context.Canceled {
3914 t.Errorf("got %v, want %v", err, context.Canceled)
3915 }
3916 if cc.closing == false {
3917 t.Error("expected closing to be true")
3918 }
3919 if cc.CanTakeNewRequest() == true {
3920 t.Error("CanTakeNewRequest to return false")
3921 }
3922 if v, want := len(cc.streams), 1; v != want {
3923 t.Errorf("expected %d active streams, got %d", want, v)
3924 }
3925 clientDone <- struct{}{}
3926 <-handlerDone
3927 case shutdown:
3928 wait := make(chan struct{})
3929 shutdownEnterWaitStateHook = func() {
3930 close(wait)
3931 shutdownEnterWaitStateHook = func() {}
3932 }
3933 defer func() { shutdownEnterWaitStateHook = func() {} }()
3934 shutdown := make(chan struct{}, 1)
3935 go func() {
3936 if err = cc.Shutdown(context.Background()); err != nil {
3937 t.Error(err)
3938 }
3939 close(shutdown)
3940 }()
3941
3942 <-wait
3943 cc.mu.Lock()
3944 if cc.closing == false {
3945 t.Error("expected closing to be true")
3946 }
3947 cc.mu.Unlock()
3948 if cc.CanTakeNewRequest() == true {
3949 t.Error("CanTakeNewRequest to return false")
3950 }
3951 if got, want := activeStreams(cc), 1; got != want {
3952 t.Errorf("got %d active streams, want %d", got, want)
3953 }
3954
3955 clientDone <- struct{}{}
3956
3957 select {
3958 case <-shutdown:
3959 case <-time.After(2 * time.Second):
3960 t.Fatal("expected server connection to close")
3961 }
3962 case closeAtHeaders, closeAtBody:
3963 if closeMode == closeAtBody {
3964 go close(sendBody)
3965 if _, err := io.Copy(io.Discard, res.Body); err == nil {
3966 t.Error("expected a Copy error, got nil")
3967 }
3968 }
3969 <-closeDone
3970 if got, want := activeStreams(cc), 0; got != want {
3971 t.Errorf("got %d active streams, want %d", got, want)
3972 }
3973
3974 select {
3975 case <-handlerDone:
3976 case <-time.After(2 * time.Second):
3977 t.Fatal("expected server connection to close")
3978 }
3979 }
3980 }
3981
3982
3983
3984
3985 func TestClientConnCloseAtHeaders(t *testing.T) {
3986 testClientConnClose(t, closeAtHeaders)
3987 }
3988
3989
3990
3991 func TestClientConnCloseAtBody(t *testing.T) {
3992 testClientConnClose(t, closeAtBody)
3993 }
3994
3995
3996
3997 func TestClientConnShutdown(t *testing.T) {
3998 testClientConnClose(t, shutdown)
3999 }
4000
4001
4002
4003
4004 func TestClientConnShutdownCancel(t *testing.T) {
4005 testClientConnClose(t, shutdownCancel)
4006 }
4007
4008
4009
4010
4011
4012
4013 func TestTransportUsesGetBodyWhenPresent(t *testing.T) {
4014 calls := 0
4015 someBody := func() io.ReadCloser {
4016 return struct{ io.ReadCloser }{io.NopCloser(bytes.NewReader(nil))}
4017 }
4018 req := &http.Request{
4019 Body: someBody(),
4020 GetBody: func() (io.ReadCloser, error) {
4021 calls++
4022 return someBody(), nil
4023 },
4024 }
4025
4026 req2, err := shouldRetryRequest(req, errClientConnUnusable)
4027 if err != nil {
4028 t.Fatal(err)
4029 }
4030 if calls != 1 {
4031 t.Errorf("Calls = %d; want 1", calls)
4032 }
4033 if req2 == req {
4034 t.Error("req2 changed")
4035 }
4036 if req2 == nil {
4037 t.Fatal("req2 is nil")
4038 }
4039 if req2.Body == nil {
4040 t.Fatal("req2.Body is nil")
4041 }
4042 if req2.GetBody == nil {
4043 t.Fatal("req2.GetBody is nil")
4044 }
4045 if req2.Body == req.Body {
4046 t.Error("req2.Body unchanged")
4047 }
4048 }
4049
4050
4051
4052 func TestNoDialH2RoundTripperType(t *testing.T) {
4053 t1 := new(http.Transport)
4054 t2 := new(Transport)
4055 rt := noDialH2RoundTripper{t2}
4056 if err := registerHTTPSProtocol(t1, rt); err != nil {
4057 t.Fatal(err)
4058 }
4059 rv := reflect.ValueOf(rt)
4060 if rv.Type().Kind() != reflect.Struct {
4061 t.Fatalf("kind = %v; net/http expects struct", rv.Type().Kind())
4062 }
4063 if n := rv.Type().NumField(); n != 1 {
4064 t.Fatalf("fields = %d; net/http expects 1", n)
4065 }
4066 v := rv.Field(0)
4067 if _, ok := v.Interface().(*Transport); !ok {
4068 t.Fatalf("wrong kind %T; want *Transport", v.Interface())
4069 }
4070 }
4071
// errReader hands out the bytes in body and, once they are used up, returns
// err (with zero bytes) from every later Read.
type errReader struct {
	body []byte
	err  error
}

func (r *errReader) Read(p []byte) (int, error) {
	if len(r.body) == 0 {
		return 0, r.err
	}
	copied := copy(p, r.body)
	r.body = r.body[copied:]
	return copied, nil
}
4085
4086 func testTransportBodyReadError(t *testing.T, body []byte) {
4087 tc := newTestClientConn(t)
4088 tc.greet()
4089
4090 bodyReadError := errors.New("body read error")
4091 b := tc.newRequestBody()
4092 b.Write(body)
4093 b.closeWithError(bodyReadError)
4094 req, _ := http.NewRequest("PUT", "https://dummy.tld/", b)
4095 rt := tc.roundTrip(req)
4096
4097 tc.wantFrameType(FrameHeaders)
4098 var receivedBody []byte
4099 readFrames:
4100 for {
4101 switch f := tc.readFrame().(type) {
4102 case *DataFrame:
4103 receivedBody = append(receivedBody, f.Data()...)
4104 case *RSTStreamFrame:
4105 break readFrames
4106 default:
4107 t.Fatalf("unexpected frame: %v", f)
4108 case nil:
4109 t.Fatalf("transport is idle, want RST_STREAM")
4110 }
4111 }
4112 if !bytes.Equal(receivedBody, body) {
4113 t.Fatalf("body: %q; expected %q", receivedBody, body)
4114 }
4115
4116 if err := rt.err(); err != bodyReadError {
4117 t.Fatalf("err = %v; want %v", err, bodyReadError)
4118 }
4119
4120 if got := activeStreams(tc.cc); got != 0 {
4121 t.Fatalf("active streams count: %v; want 0", got)
4122 }
4123 }
4124
4125 func TestTransportBodyReadError_Immediately(t *testing.T) { testTransportBodyReadError(t, nil) }
4126 func TestTransportBodyReadError_Some(t *testing.T) { testTransportBodyReadError(t, []byte("123")) }
4127
4128
4129
4130
4131 func TestTransportBodyEagerEndStream(t *testing.T) {
4132 const reqBody = "some request body"
4133 const resBody = "some response body"
4134
4135 tc := newTestClientConn(t)
4136 tc.greet()
4137
4138 body := strings.NewReader(reqBody)
4139 req, _ := http.NewRequest("PUT", "https://dummy.tld/", body)
4140 tc.roundTrip(req)
4141
4142 tc.wantFrameType(FrameHeaders)
4143 f := readFrame[*DataFrame](t, tc)
4144 if !f.StreamEnded() {
4145 t.Fatalf("data frame without END_STREAM %v", f)
4146 }
4147 }
4148
// chunkReader returns one element of chunks per Read (truncated to the
// caller's buffer) and panics once the chunks are exhausted, so a test notices
// if the body is read more often than expected.
type chunkReader struct {
	chunks [][]byte
}

func (r *chunkReader) Read(p []byte) (int, error) {
	if len(r.chunks) == 0 {
		panic("shouldn't read this many times")
	}
	head := r.chunks[0]
	r.chunks = r.chunks[1:]
	return copy(p, head), nil
}
4161
4162
4163
4164
4165
4166
4167
4168
4169
4170 func TestTransportBodyLargerThanSpecifiedContentLength_len3(t *testing.T) {
4171 body := &chunkReader{[][]byte{
4172 []byte("123"),
4173 []byte("456"),
4174 }}
4175 testTransportBodyLargerThanSpecifiedContentLength(t, body, 3)
4176 }
4177
4178 func TestTransportBodyLargerThanSpecifiedContentLength_len2(t *testing.T) {
4179 body := &chunkReader{[][]byte{
4180 []byte("123"),
4181 }}
4182 testTransportBodyLargerThanSpecifiedContentLength(t, body, 2)
4183 }
4184
4185 func testTransportBodyLargerThanSpecifiedContentLength(t *testing.T, body *chunkReader, contentLen int64) {
4186 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
4187 r.Body.Read(make([]byte, 6))
4188 })
4189
4190 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
4191 defer tr.CloseIdleConnections()
4192
4193 req, _ := http.NewRequest("POST", ts.URL, body)
4194 req.ContentLength = contentLen
4195 _, err := tr.RoundTrip(req)
4196 if err != errReqBodyTooLong {
4197 t.Fatalf("expected %v, got %v", errReqBodyTooLong, err)
4198 }
4199 }
4200
4201 func TestClientConnTooIdle(t *testing.T) {
4202 tests := []struct {
4203 cc func() *ClientConn
4204 want bool
4205 }{
4206 {
4207 func() *ClientConn {
4208 return &ClientConn{idleTimeout: 5 * time.Second, lastIdle: time.Now().Add(-10 * time.Second)}
4209 },
4210 true,
4211 },
4212 {
4213 func() *ClientConn {
4214 return &ClientConn{idleTimeout: 5 * time.Second, lastIdle: time.Time{}}
4215 },
4216 false,
4217 },
4218 {
4219 func() *ClientConn {
4220 return &ClientConn{idleTimeout: 60 * time.Second, lastIdle: time.Now().Add(-10 * time.Second)}
4221 },
4222 false,
4223 },
4224 {
4225 func() *ClientConn {
4226 return &ClientConn{idleTimeout: 0, lastIdle: time.Now().Add(-10 * time.Second)}
4227 },
4228 false,
4229 },
4230 }
4231 for i, tt := range tests {
4232 got := tt.cc().tooIdleLocked()
4233 if got != tt.want {
4234 t.Errorf("%d. got %v; want %v", i, got, tt.want)
4235 }
4236 }
4237 }
4238
// fakeConnErr is a net.Conn stub. Write always fails with writeErr, and Close
// only records that it was called. All other methods fall through to the
// embedded net.Conn, which may be nil.
type fakeConnErr struct {
	net.Conn
	writeErr error
	closed   bool
}

func (c *fakeConnErr) Write([]byte) (int, error) {
	return 0, c.writeErr
}

func (c *fakeConnErr) Close() error {
	c.closed = true
	return nil
}
4253
4254
4255 func TestTransportNewClientConnCloseOnWriteError(t *testing.T) {
4256 tr := &Transport{}
4257 writeErr := errors.New("write error")
4258 fakeConn := &fakeConnErr{writeErr: writeErr}
4259 _, err := tr.NewClientConn(fakeConn)
4260 if err != writeErr {
4261 t.Fatalf("expected %v, got %v", writeErr, err)
4262 }
4263 if !fakeConn.closed {
4264 t.Error("expected closed conn")
4265 }
4266 }
4267
4268 func TestTransportRoundtripCloseOnWriteError(t *testing.T) {
4269 req, err := http.NewRequest("GET", "https://dummy.tld/", nil)
4270 if err != nil {
4271 t.Fatal(err)
4272 }
4273 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {})
4274
4275 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
4276 defer tr.CloseIdleConnections()
4277 ctx := context.Background()
4278 cc, err := tr.dialClientConn(ctx, ts.Listener.Addr().String(), false)
4279 if err != nil {
4280 t.Fatal(err)
4281 }
4282
4283 writeErr := errors.New("write error")
4284 cc.wmu.Lock()
4285 cc.werr = writeErr
4286 cc.wmu.Unlock()
4287
4288 _, err = cc.RoundTrip(req)
4289 if err != writeErr {
4290 t.Fatalf("expected %v, got %v", writeErr, err)
4291 }
4292
4293 cc.mu.Lock()
4294 closed := cc.closed
4295 cc.mu.Unlock()
4296 if !closed {
4297 t.Fatal("expected closed")
4298 }
4299 }
4300
4301
4302
4303
4304 func TestTransportBodyRewindRace(t *testing.T) {
4305 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
4306 w.Header().Set("Connection", "close")
4307 w.WriteHeader(http.StatusOK)
4308 return
4309 })
4310
4311 tr := &http.Transport{
4312 TLSClientConfig: tlsConfigInsecure,
4313 MaxConnsPerHost: 1,
4314 }
4315 err := ConfigureTransport(tr)
4316 if err != nil {
4317 t.Fatal(err)
4318 }
4319 client := &http.Client{
4320 Transport: tr,
4321 }
4322
4323 const clients = 50
4324
4325 var wg sync.WaitGroup
4326 wg.Add(clients)
4327 for i := 0; i < clients; i++ {
4328 req, err := http.NewRequest("POST", ts.URL, bytes.NewBufferString("abcdef"))
4329 if err != nil {
4330 t.Fatalf("unexpect new request error: %v", err)
4331 }
4332
4333 go func() {
4334 defer wg.Done()
4335 res, err := client.Do(req)
4336 if err == nil {
4337 res.Body.Close()
4338 }
4339 }()
4340 }
4341
4342 wg.Wait()
4343 }
4344
4345
4346
4347 func TestTransportServerResetStreamAtHeaders(t *testing.T) {
4348 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
4349 w.WriteHeader(http.StatusUnauthorized)
4350 return
4351 })
4352
4353 tr := &http.Transport{
4354 TLSClientConfig: tlsConfigInsecure,
4355 MaxConnsPerHost: 1,
4356 ExpectContinueTimeout: 10 * time.Second,
4357 }
4358
4359 err := ConfigureTransport(tr)
4360 if err != nil {
4361 t.Fatal(err)
4362 }
4363 client := &http.Client{
4364 Transport: tr,
4365 }
4366
4367 req, err := http.NewRequest("POST", ts.URL, errorReader{io.EOF})
4368 if err != nil {
4369 t.Fatalf("unexpect new request error: %v", err)
4370 }
4371 req.ContentLength = 0
4372 req.Header.Set("Expect", "100-continue")
4373 res, err := client.Do(req)
4374 if err != nil {
4375 t.Fatal(err)
4376 }
4377 res.Body.Close()
4378 }
4379
// trackingReader wraps rdr and records, using atomic operations, whether Read
// has ever been called.
type trackingReader struct {
	rdr     io.Reader
	wasRead uint32 // stored as 1 by every Read
}

func (r *trackingReader) Read(p []byte) (int, error) {
	atomic.StoreUint32(&r.wasRead, 1)
	return r.rdr.Read(p)
}

// WasRead reports whether Read has been called at least once.
func (r *trackingReader) WasRead() bool {
	return atomic.LoadUint32(&r.wasRead) == 1
}
4393
4394 func TestTransportExpectContinue(t *testing.T) {
4395 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
4396 switch r.URL.Path {
4397 case "/reject":
4398 w.WriteHeader(403)
4399 default:
4400 io.Copy(io.Discard, r.Body)
4401 }
4402 })
4403
4404 tr := &http.Transport{
4405 TLSClientConfig: tlsConfigInsecure,
4406 MaxConnsPerHost: 1,
4407 ExpectContinueTimeout: 10 * time.Second,
4408 }
4409
4410 err := ConfigureTransport(tr)
4411 if err != nil {
4412 t.Fatal(err)
4413 }
4414 client := &http.Client{
4415 Transport: tr,
4416 }
4417
4418 testCases := []struct {
4419 Name string
4420 Path string
4421 Body *trackingReader
4422 ExpectedCode int
4423 ShouldRead bool
4424 }{
4425 {
4426 Name: "read-all",
4427 Path: "/",
4428 Body: &trackingReader{rdr: strings.NewReader("hello")},
4429 ExpectedCode: 200,
4430 ShouldRead: true,
4431 },
4432 {
4433 Name: "reject",
4434 Path: "/reject",
4435 Body: &trackingReader{rdr: strings.NewReader("hello")},
4436 ExpectedCode: 403,
4437 ShouldRead: false,
4438 },
4439 }
4440
4441 for _, tc := range testCases {
4442 t.Run(tc.Name, func(t *testing.T) {
4443 startTime := time.Now()
4444
4445 req, err := http.NewRequest("POST", ts.URL+tc.Path, tc.Body)
4446 if err != nil {
4447 t.Fatal(err)
4448 }
4449 req.Header.Set("Expect", "100-continue")
4450 res, err := client.Do(req)
4451 if err != nil {
4452 t.Fatal(err)
4453 }
4454 res.Body.Close()
4455
4456 if delta := time.Since(startTime); delta >= tr.ExpectContinueTimeout {
4457 t.Error("Request didn't finish before expect continue timeout")
4458 }
4459 if res.StatusCode != tc.ExpectedCode {
4460 t.Errorf("Unexpected status code, got %d, expected %d", res.StatusCode, tc.ExpectedCode)
4461 }
4462 if tc.Body.WasRead() != tc.ShouldRead {
4463 t.Errorf("Unexpected read status, got %v, expected %v", tc.Body.WasRead(), tc.ShouldRead)
4464 }
4465 })
4466 }
4467 }
4468
// closeChecker wraps an io.ReadCloser and records when Close is called, so
// tests can check that a request body was closed and was not read afterwards.
type closeChecker struct {
	io.ReadCloser
	closed chan struct{} // closed by Close
}

// newCloseChecker returns a closeChecker wrapping r.
func newCloseChecker(r io.ReadCloser) *closeChecker {
	return &closeChecker{r, make(chan struct{})}
}

// newStaticCloseChecker returns a closeChecker whose content is body.
func newStaticCloseChecker(body string) *closeChecker {
	// Use the caller's body; previously this ignored the argument and always
	// served the literal "body".
	return newCloseChecker(io.NopCloser(strings.NewReader(body)))
}

// Read fails with an error once Close has been called; otherwise it reads
// from the wrapped ReadCloser.
func (rc *closeChecker) Read(b []byte) (n int, err error) {
	select {
	default:
	case <-rc.closed:
		// Reading after Close is a bug in the code under test; report it
		// rather than delegating to the closed reader.
		return 0, errors.New("read after Body.Close")
	}
	return rc.ReadCloser.Read(b)
}

// Close marks the checker closed and closes the wrapped ReadCloser. A second
// call panics because it closes rc.closed again.
func (rc *closeChecker) Close() error {
	close(rc.closed)
	return rc.ReadCloser.Close()
}

// isClosed waits up to 10 seconds for Close to be called and returns an error
// if it was not.
func (rc *closeChecker) isClosed() error {
	// Wait rather than check once; presumably the body may be closed
	// asynchronously after the caller's RoundTrip returns.
	const timeout = 10 * time.Second
	select {
	case <-rc.closed:
		return nil
	case <-time.After(timeout):
		return fmt.Errorf("body not closed after %v", timeout)
	}
}
4511
4512
// blockingWriteConn wraps a net.Conn and passes writes through until the
// running total would exceed limit. Every Write that would cross the limit
// blocks until unblock is called; the first such Write also closes writec so
// that wait returns.
type blockingWriteConn struct {
	net.Conn
	writeOnce    sync.Once
	writec       chan struct{} // closed by the first over-limit Write
	unblockc     chan struct{} // closed by unblock to release over-limit Writes
	count, limit int
}

// newBlockingWriteConn wraps conn with a write budget of limit bytes.
func newBlockingWriteConn(conn net.Conn, limit int) *blockingWriteConn {
	c := &blockingWriteConn{Conn: conn, limit: limit}
	c.writec = make(chan struct{})
	c.unblockc = make(chan struct{})
	return c
}

// wait blocks until some Write has tried to exceed the limit.
func (c *blockingWriteConn) wait() { <-c.writec }

// unblock releases all current and future over-limit Writes.
func (c *blockingWriteConn) unblock() { close(c.unblockc) }

func (c *blockingWriteConn) Write(b []byte) (n int, err error) {
	if over := c.count+len(b) > c.limit; over {
		c.writeOnce.Do(func() { close(c.writec) })
		<-c.unblockc
	}
	n, err = c.Conn.Write(b)
	c.count += n
	return n, err
}
4551
4552
4553
4554 func TestTransportFrameBufferReuse(t *testing.T) {
4555 filler := hex.EncodeToString([]byte(randString(2048)))
4556
4557 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
4558 if got, want := r.Header.Get("Big"), filler; got != want {
4559 t.Errorf(`r.Header.Get("Big") = %q, want %q`, got, want)
4560 }
4561 b, err := io.ReadAll(r.Body)
4562 if err != nil {
4563 t.Errorf("error reading request body: %v", err)
4564 }
4565 if got, want := string(b), filler; got != want {
4566 t.Errorf("request body = %q, want %q", got, want)
4567 }
4568 if got, want := r.Trailer.Get("Big"), filler; got != want {
4569 t.Errorf(`r.Trailer.Get("Big") = %q, want %q`, got, want)
4570 }
4571 })
4572
4573 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
4574 defer tr.CloseIdleConnections()
4575
4576 var wg sync.WaitGroup
4577 defer wg.Wait()
4578 for i := 0; i < 10; i++ {
4579 wg.Add(1)
4580 go func() {
4581 defer wg.Done()
4582 req, err := http.NewRequest("POST", ts.URL, strings.NewReader(filler))
4583 if err != nil {
4584 t.Error(err)
4585 return
4586 }
4587 req.Header.Set("Big", filler)
4588 req.Trailer = make(http.Header)
4589 req.Trailer.Set("Big", filler)
4590 res, err := tr.RoundTrip(req)
4591 if err != nil {
4592 t.Error(err)
4593 return
4594 }
4595 if got, want := res.StatusCode, 200; got != want {
4596 t.Errorf("StatusCode = %v; want %v", got, want)
4597 }
4598 if res != nil && res.Body != nil {
4599 res.Body.Close()
4600 }
4601 }()
4602 }
4603
4604 }
4605
4606
4607
4608
4609
4610 func TestTransportBlockingRequestWrite(t *testing.T) {
4611 filler := hex.EncodeToString([]byte(randString(2048)))
4612 for _, test := range []struct {
4613 name string
4614 req func(url string) (*http.Request, error)
4615 }{{
4616 name: "headers",
4617 req: func(url string) (*http.Request, error) {
4618 req, err := http.NewRequest("POST", url, nil)
4619 if err != nil {
4620 return nil, err
4621 }
4622 req.Header.Set("Big", filler)
4623 return req, err
4624 },
4625 }, {
4626 name: "body",
4627 req: func(url string) (*http.Request, error) {
4628 req, err := http.NewRequest("POST", url, strings.NewReader(filler))
4629 if err != nil {
4630 return nil, err
4631 }
4632 return req, err
4633 },
4634 }, {
4635 name: "trailer",
4636 req: func(url string) (*http.Request, error) {
4637 req, err := http.NewRequest("POST", url, strings.NewReader("body"))
4638 if err != nil {
4639 return nil, err
4640 }
4641 req.Trailer = make(http.Header)
4642 req.Trailer.Set("Big", filler)
4643 return req, err
4644 },
4645 }} {
4646 test := test
4647 t.Run(test.name, func(t *testing.T) {
4648 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
4649 if v := r.Header.Get("Big"); v != "" && v != filler {
4650 t.Errorf("request header mismatch")
4651 }
4652 if v, _ := io.ReadAll(r.Body); len(v) != 0 && string(v) != "body" && string(v) != filler {
4653 t.Errorf("request body mismatch\ngot: %q\nwant: %q", string(v), filler)
4654 }
4655 if v := r.Trailer.Get("Big"); v != "" && v != filler {
4656 t.Errorf("request trailer mismatch\ngot: %q\nwant: %q", string(v), filler)
4657 }
4658 }, func(s *Server) {
4659 s.MaxConcurrentStreams = 1
4660 })
4661
4662
4663 connc := make(chan *blockingWriteConn, 1)
4664 connCount := 0
4665 tr := &Transport{
4666 TLSClientConfig: tlsConfigInsecure,
4667 DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
4668 connCount++
4669 c, err := tls.Dial(network, addr, cfg)
4670 wc := newBlockingWriteConn(c, 1024)
4671 select {
4672 case connc <- wc:
4673 default:
4674 }
4675 return wc, err
4676 },
4677 }
4678 defer tr.CloseIdleConnections()
4679
4680
4681 {
4682 req, err := http.NewRequest("POST", ts.URL, nil)
4683 if err != nil {
4684 t.Fatal(err)
4685 }
4686 res, err := tr.RoundTrip(req)
4687 if err != nil {
4688 t.Fatal(err)
4689 }
4690 if got, want := res.StatusCode, 200; got != want {
4691 t.Errorf("StatusCode = %v; want %v", got, want)
4692 }
4693 if res != nil && res.Body != nil {
4694 res.Body.Close()
4695 }
4696 }
4697
4698
4699 reqc := make(chan struct{})
4700 go func() {
4701 defer close(reqc)
4702 req, err := test.req(ts.URL)
4703 if err != nil {
4704 t.Error(err)
4705 return
4706 }
4707 res, _ := tr.RoundTrip(req)
4708 if res != nil && res.Body != nil {
4709 res.Body.Close()
4710 }
4711 }()
4712 conn := <-connc
4713 conn.wait()
4714
4715
4716
4717 {
4718 req, err := http.NewRequest("POST", ts.URL, nil)
4719 if err != nil {
4720 t.Fatal(err)
4721 }
4722 res, err := tr.RoundTrip(req)
4723 if err != nil {
4724 t.Fatal(err)
4725 }
4726 if got, want := res.StatusCode, 200; got != want {
4727 t.Errorf("StatusCode = %v; want %v", got, want)
4728 }
4729 if res != nil && res.Body != nil {
4730 res.Body.Close()
4731 }
4732 }
4733
4734
4735 select {
4736 case <-reqc:
4737 t.Errorf("request 2 unexpectedly completed")
4738 default:
4739 }
4740
4741 conn.unblock()
4742 <-reqc
4743
4744 if connCount != 2 {
4745 t.Errorf("created %v connections, want 1", connCount)
4746 }
4747 })
4748 }
4749 }
4750
4751 func TestTransportCloseRequestBody(t *testing.T) {
4752 var statusCode int
4753 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
4754 w.WriteHeader(statusCode)
4755 })
4756
4757 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
4758 defer tr.CloseIdleConnections()
4759 ctx := context.Background()
4760 cc, err := tr.dialClientConn(ctx, ts.Listener.Addr().String(), false)
4761 if err != nil {
4762 t.Fatal(err)
4763 }
4764
4765 for _, status := range []int{200, 401} {
4766 t.Run(fmt.Sprintf("status=%d", status), func(t *testing.T) {
4767 statusCode = status
4768 pr, pw := io.Pipe()
4769 body := newCloseChecker(pr)
4770 req, err := http.NewRequest("PUT", "https://dummy.tld/", body)
4771 if err != nil {
4772 t.Fatal(err)
4773 }
4774 res, err := cc.RoundTrip(req)
4775 if err != nil {
4776 t.Fatal(err)
4777 }
4778 res.Body.Close()
4779 pw.Close()
4780 if err := body.isClosed(); err != nil {
4781 t.Fatal(err)
4782 }
4783 })
4784 }
4785 }
4786
4787 func TestTransportRetriesOnStreamProtocolError(t *testing.T) {
4788
4789
4790
4791
4792
4793 tt := newTestTransport(t)
4794
4795
4796
4797
4798
4799
4800 req1, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
4801 rt1 := tt.roundTrip(req1)
4802 tc1 := tt.getConn()
4803 tc1.wantFrameType(FrameSettings)
4804 tc1.wantFrameType(FrameWindowUpdate)
4805 tc1.wantHeaders(wantHeader{
4806 streamID: 1,
4807 endStream: true,
4808 })
4809 tc1.writeSettings()
4810 tc1.wantFrameType(FrameSettings)
4811
4812
4813 req2, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
4814 rt2 := tt.roundTrip(req2)
4815 tc1.wantHeaders(wantHeader{
4816 streamID: 3,
4817 endStream: true,
4818 })
4819
4820
4821 tc1.writeRSTStream(3, ErrCodeProtocol)
4822 if rt1.done() {
4823 t.Fatalf("After protocol error on RoundTrip #2, RoundTrip #1 is done; want still in progress")
4824 }
4825 if rt2.done() {
4826 t.Fatalf("After protocol error on RoundTrip #2, RoundTrip #2 is done; want still in progress")
4827 }
4828
4829
4830 tc2 := tt.getConn()
4831 tc2.wantFrameType(FrameSettings)
4832 tc2.wantFrameType(FrameWindowUpdate)
4833 tc2.wantHeaders(wantHeader{
4834 streamID: 1,
4835 endStream: true,
4836 })
4837 tc2.writeSettings()
4838 tc2.wantFrameType(FrameSettings)
4839
4840
4841 tc2.writeHeaders(HeadersFrameParam{
4842 StreamID: 1,
4843 EndHeaders: true,
4844 EndStream: true,
4845 BlockFragment: tc1.makeHeaderBlockFragment(
4846 ":status", "201",
4847 ),
4848 })
4849 rt2.wantStatus(201)
4850
4851
4852 tc1.writeHeaders(HeadersFrameParam{
4853 StreamID: 1,
4854 EndHeaders: true,
4855 EndStream: true,
4856 BlockFragment: tc1.makeHeaderBlockFragment(
4857 ":status", "200",
4858 ),
4859 })
4860 rt1.wantStatus(200)
4861 }
4862
// TestClientConnReservations exercises ClientConn.ReserveNewRequest against a
// peer advertising SETTINGS_MAX_CONCURRENT_STREAMS = initialMaxConcurrentStreams.
func TestClientConnReservations(t *testing.T) {
	tc := newTestClientConn(t)
	tc.greet(
		Setting{ID: SettingMaxConcurrentStreams, Val: initialMaxConcurrentStreams},
	)

	// doRoundTrip performs one GET and has the server answer it with a 200
	// HEADERS frame that ends the stream.
	doRoundTrip := func() {
		req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
		rt := tc.roundTrip(req)
		tc.wantFrameType(FrameHeaders)
		tc.writeHeaders(HeadersFrameParam{
			StreamID:   rt.streamID(),
			EndHeaders: true,
			EndStream:  true,
			BlockFragment: tc.makeHeaderBlockFragment(
				":status", "200",
			),
		})
		rt.wantStatus(200)
	}

	// Reserve until refused (trying at most one more than the limit); exactly
	// initialMaxConcurrentStreams reservations are expected to succeed.
	n := 0
	for n <= initialMaxConcurrentStreams && tc.cc.ReserveNewRequest() {
		n++
	}
	if n != initialMaxConcurrentStreams {
		t.Errorf("did %v reservations; want %v", n, initialMaxConcurrentStreams)
	}
	// After one completed round trip, exactly one more reservation should succeed.
	doRoundTrip()
	n2 := 0
	for n2 <= 5 && tc.cc.ReserveNewRequest() {
		n2++
	}
	if n2 != 1 {
		t.Fatalf("after one RoundTrip, did %v reservations; want 1", n2)
	}

	// Run n more round trips; afterwards all n slots should be reservable again.
	for i := 0; i < n; i++ {
		doRoundTrip()
	}

	n2 = 0
	for n2 <= initialMaxConcurrentStreams && tc.cc.ReserveNewRequest() {
		n2++
	}
	if n2 != n {
		t.Errorf("after reset, reservations = %v; want %v", n2, n)
	}
}
4913
4914 func TestTransportTimeoutServerHangs(t *testing.T) {
4915 tc := newTestClientConn(t)
4916 tc.greet()
4917
4918 ctx, cancel := context.WithCancel(context.Background())
4919 req, _ := http.NewRequestWithContext(ctx, "PUT", "https://dummy.tld/", nil)
4920 rt := tc.roundTrip(req)
4921
4922 tc.wantFrameType(FrameHeaders)
4923 tc.advance(5 * time.Second)
4924 if f := tc.readFrame(); f != nil {
4925 t.Fatalf("unexpected frame: %v", f)
4926 }
4927 if rt.done() {
4928 t.Fatalf("after 5 seconds with no response, RoundTrip unexpectedly returned")
4929 }
4930
4931 cancel()
4932 tc.sync()
4933 if rt.err() != context.Canceled {
4934 t.Fatalf("RoundTrip error: %v; want context.Canceled", rt.err())
4935 }
4936 }
4937
4938 func TestTransportContentLengthWithoutBody(t *testing.T) {
4939 contentLength := ""
4940 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
4941 w.Header().Set("Content-Length", contentLength)
4942 })
4943 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
4944 defer tr.CloseIdleConnections()
4945
4946 for _, test := range []struct {
4947 name string
4948 contentLength string
4949 wantBody string
4950 wantErr error
4951 wantContentLength int64
4952 }{
4953 {
4954 name: "non-zero content length",
4955 contentLength: "42",
4956 wantErr: io.ErrUnexpectedEOF,
4957 wantContentLength: 42,
4958 },
4959 {
4960 name: "zero content length",
4961 contentLength: "0",
4962 wantErr: nil,
4963 wantContentLength: 0,
4964 },
4965 } {
4966 t.Run(test.name, func(t *testing.T) {
4967 contentLength = test.contentLength
4968
4969 req, _ := http.NewRequest("GET", ts.URL, nil)
4970 res, err := tr.RoundTrip(req)
4971 if err != nil {
4972 t.Fatal(err)
4973 }
4974 defer res.Body.Close()
4975 body, err := io.ReadAll(res.Body)
4976
4977 if err != test.wantErr {
4978 t.Errorf("Expected error %v, got: %v", test.wantErr, err)
4979 }
4980 if len(body) > 0 {
4981 t.Errorf("Expected empty body, got: %v", body)
4982 }
4983 if res.ContentLength != test.wantContentLength {
4984 t.Errorf("Expected content length %d, got: %d", test.wantContentLength, res.ContentLength)
4985 }
4986 })
4987 }
4988 }
4989
4990 func TestTransportCloseResponseBodyWhileRequestBodyHangs(t *testing.T) {
4991 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
4992 w.WriteHeader(200)
4993 w.(http.Flusher).Flush()
4994 io.Copy(io.Discard, r.Body)
4995 })
4996
4997 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
4998 defer tr.CloseIdleConnections()
4999
5000 pr, pw := net.Pipe()
5001 req, err := http.NewRequest("GET", ts.URL, pr)
5002 if err != nil {
5003 t.Fatal(err)
5004 }
5005 res, err := tr.RoundTrip(req)
5006 if err != nil {
5007 t.Fatal(err)
5008 }
5009
5010 res.Body.Close()
5011 pw.Close()
5012 }
5013
5014 func TestTransport300ResponseBody(t *testing.T) {
5015 reqc := make(chan struct{})
5016 body := []byte("response body")
5017 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
5018 w.WriteHeader(300)
5019 w.(http.Flusher).Flush()
5020 <-reqc
5021 w.Write(body)
5022 })
5023
5024 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
5025 defer tr.CloseIdleConnections()
5026
5027 pr, pw := net.Pipe()
5028 req, err := http.NewRequest("GET", ts.URL, pr)
5029 if err != nil {
5030 t.Fatal(err)
5031 }
5032 res, err := tr.RoundTrip(req)
5033 if err != nil {
5034 t.Fatal(err)
5035 }
5036 close(reqc)
5037 got, err := io.ReadAll(res.Body)
5038 if err != nil {
5039 t.Fatalf("error reading response body: %v", err)
5040 }
5041 if !bytes.Equal(got, body) {
5042 t.Errorf("got response body %q, want %q", string(got), string(body))
5043 }
5044 res.Body.Close()
5045 pw.Close()
5046 }
5047
5048 func TestTransportWriteByteTimeout(t *testing.T) {
5049 ts := newTestServer(t,
5050 func(w http.ResponseWriter, r *http.Request) {},
5051 )
5052 tr := &Transport{
5053 TLSClientConfig: tlsConfigInsecure,
5054 DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
5055 _, c := net.Pipe()
5056 return c, nil
5057 },
5058 WriteByteTimeout: 1 * time.Millisecond,
5059 }
5060 defer tr.CloseIdleConnections()
5061 c := &http.Client{Transport: tr}
5062
5063 _, err := c.Get(ts.URL)
5064 if !errors.Is(err, os.ErrDeadlineExceeded) {
5065 t.Fatalf("Get on unresponsive connection: got %q; want ErrDeadlineExceeded", err)
5066 }
5067 }
5068
// slowWriteConn wraps a net.Conn so that, while a write deadline is set, each
// Write of more than one byte transmits only its first byte and then reports
// a deadline error. Writes with no deadline, or of at most one byte, pass
// straight through to the wrapped conn.
type slowWriteConn struct {
	net.Conn
	hasWriteDeadline bool // true while a non-zero write deadline is in effect
}

// SetWriteDeadline only records whether a deadline is active; the deadline
// is not forwarded to the wrapped connection.
func (c *slowWriteConn) SetWriteDeadline(t time.Time) error {
	c.hasWriteDeadline = !t.IsZero()
	return nil
}

// Write implements the one-byte-then-timeout behavior described on the type.
func (c *slowWriteConn) Write(b []byte) (int, error) {
	if !c.hasWriteDeadline || len(b) <= 1 {
		return c.Conn.Write(b)
	}
	written, err := c.Conn.Write(b[:1])
	if err == nil {
		err = fmt.Errorf("slow write: %w", os.ErrDeadlineExceeded)
	}
	return written, err
}
5089
5090 func TestTransportSlowWrites(t *testing.T) {
5091 ts := newTestServer(t,
5092 func(w http.ResponseWriter, r *http.Request) {},
5093 )
5094 tr := &Transport{
5095 TLSClientConfig: tlsConfigInsecure,
5096 DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) {
5097 cfg.InsecureSkipVerify = true
5098 c, err := tls.Dial(network, addr, cfg)
5099 return &slowWriteConn{Conn: c}, err
5100 },
5101 WriteByteTimeout: 1 * time.Millisecond,
5102 }
5103 defer tr.CloseIdleConnections()
5104 c := &http.Client{Transport: tr}
5105
5106 const bodySize = 1 << 20
5107 resp, err := c.Post(ts.URL, "text/foo", io.LimitReader(neverEnding('A'), bodySize))
5108 if err != nil {
5109 t.Fatal(err)
5110 }
5111 resp.Body.Close()
5112 }
5113
5114 func TestTransportClosesConnAfterGoAwayNoStreams(t *testing.T) {
5115 testTransportClosesConnAfterGoAway(t, 0)
5116 }
5117 func TestTransportClosesConnAfterGoAwayLastStream(t *testing.T) {
5118 testTransportClosesConnAfterGoAway(t, 1)
5119 }
5120
5121
5122
5123
5124
5125
5126
5127 func testTransportClosesConnAfterGoAway(t *testing.T, lastStream uint32) {
5128 tc := newTestClientConn(t)
5129 tc.greet()
5130
5131 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
5132 rt := tc.roundTrip(req)
5133
5134 tc.wantFrameType(FrameHeaders)
5135 tc.writeGoAway(lastStream, ErrCodeNo, nil)
5136
5137 if lastStream > 0 {
5138
5139 tc.writeHeaders(HeadersFrameParam{
5140 StreamID: rt.streamID(),
5141 EndHeaders: true,
5142 EndStream: true,
5143 BlockFragment: tc.makeHeaderBlockFragment(
5144 ":status", "200",
5145 ),
5146 })
5147 }
5148
5149 tc.closeWrite()
5150 err := rt.err()
5151 if gotErr, wantErr := err != nil, lastStream == 0; gotErr != wantErr {
5152 t.Errorf("RoundTrip got error %v (want error: %v)", err, wantErr)
5153 }
5154 if !tc.isClosed() {
5155 t.Errorf("ClientConn did not close its net.Conn, expected it to")
5156 }
5157 }
5158
// slowCloser is a request body that is always empty but whose Close blocks.
// Close first closes the closing channel, then waits until the test closes
// (or sends on) the closed channel.
type slowCloser struct {
	closing chan struct{} // closed as soon as Close is entered
	closed  chan struct{} // Close returns once this becomes receivable
}

// Read reports an already-exhausted body.
func (r *slowCloser) Read(_ []byte) (int, error) {
	return 0, io.EOF
}

// Close announces itself on closing and then blocks until released via closed.
func (r *slowCloser) Close() error {
	close(r.closing)
	<-r.closed
	return nil
}
5173
5174 func TestTransportSlowClose(t *testing.T) {
5175 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
5176 })
5177
5178 client := ts.Client()
5179 body := &slowCloser{
5180 closing: make(chan struct{}),
5181 closed: make(chan struct{}),
5182 }
5183
5184 reqc := make(chan struct{})
5185 go func() {
5186 defer close(reqc)
5187 res, err := client.Post(ts.URL, "text/plain", body)
5188 if err != nil {
5189 t.Error(err)
5190 }
5191 res.Body.Close()
5192 }()
5193 defer func() {
5194 close(body.closed)
5195 <-reqc
5196 }()
5197
5198 <-body.closing
5199
5200 res, err := client.Get(ts.URL)
5201 if err != nil {
5202 t.Fatal(err)
5203 }
5204 res.Body.Close()
5205 }
5206
5207 func TestTransportDialTLSContext(t *testing.T) {
5208 blockCh := make(chan struct{})
5209 serverTLSConfigFunc := func(ts *httptest.Server) {
5210 ts.Config.TLSConfig = &tls.Config{
5211
5212
5213 ClientAuth: tls.RequestClientCert,
5214 }
5215 }
5216 ts := newTestServer(t,
5217 func(w http.ResponseWriter, r *http.Request) {},
5218 serverTLSConfigFunc,
5219 )
5220 tr := &Transport{
5221 TLSClientConfig: &tls.Config{
5222 GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
5223
5224
5225 close(blockCh)
5226 <-cri.Context().Done()
5227 return nil, cri.Context().Err()
5228 },
5229 InsecureSkipVerify: true,
5230 },
5231 }
5232 defer tr.CloseIdleConnections()
5233 req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
5234 if err != nil {
5235 t.Fatal(err)
5236 }
5237 ctx, cancel := context.WithCancel(context.Background())
5238 defer cancel()
5239 req = req.WithContext(ctx)
5240 errCh := make(chan error)
5241 go func() {
5242 defer close(errCh)
5243 res, err := tr.RoundTrip(req)
5244 if err != nil {
5245 errCh <- err
5246 return
5247 }
5248 res.Body.Close()
5249 }()
5250
5251 <-blockCh
5252
5253 cancel()
5254
5255 err = <-errCh
5256 if err == nil {
5257 t.Fatal("cancelling context during client certificate fetch did not error as expected")
5258 return
5259 }
5260 if !errors.Is(err, context.Canceled) {
5261 t.Fatalf("unexpected error returned after cancellation: %v", err)
5262 }
5263 }
5264
5265
5266
5267
5268
5269 func TestDialRaceResumesDial(t *testing.T) {
5270 blockCh := make(chan struct{})
5271 serverTLSConfigFunc := func(ts *httptest.Server) {
5272 ts.Config.TLSConfig = &tls.Config{
5273
5274
5275 ClientAuth: tls.RequestClientCert,
5276 }
5277 }
5278 ts := newTestServer(t,
5279 func(w http.ResponseWriter, r *http.Request) {},
5280 serverTLSConfigFunc,
5281 )
5282 tr := &Transport{
5283 TLSClientConfig: &tls.Config{
5284 GetClientCertificate: func(cri *tls.CertificateRequestInfo) (*tls.Certificate, error) {
5285 select {
5286 case <-blockCh:
5287
5288 return &tls.Certificate{}, nil
5289 default:
5290 }
5291 close(blockCh)
5292 <-cri.Context().Done()
5293 return nil, cri.Context().Err()
5294 },
5295 InsecureSkipVerify: true,
5296 },
5297 }
5298 defer tr.CloseIdleConnections()
5299 req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
5300 if err != nil {
5301 t.Fatal(err)
5302 }
5303
5304 ctx1, cancel1 := context.WithCancel(context.Background())
5305 defer cancel1()
5306 req1 := req.WithContext(ctx1)
5307 ctx2, cancel2 := context.WithCancel(context.Background())
5308 defer cancel2()
5309 req2 := req.WithContext(ctx2)
5310 errCh := make(chan error)
5311 go func() {
5312 res, err := tr.RoundTrip(req1)
5313 if err != nil {
5314 errCh <- err
5315 return
5316 }
5317 res.Body.Close()
5318 }()
5319 successCh := make(chan struct{})
5320 go func() {
5321
5322
5323 <-blockCh
5324 res, err := tr.RoundTrip(req2)
5325 if err != nil {
5326 errCh <- err
5327 return
5328 }
5329 res.Body.Close()
5330
5331
5332 close(successCh)
5333 }()
5334
5335 <-blockCh
5336
5337 cancel1()
5338
5339 err = <-errCh
5340 if err == nil {
5341 t.Fatal("cancelling context during client certificate fetch did not error as expected")
5342 return
5343 }
5344 if !errors.Is(err, context.Canceled) {
5345 t.Fatalf("unexpected error returned after cancellation: %v", err)
5346 }
5347 select {
5348 case err := <-errCh:
5349 t.Fatalf("unexpected second error: %v", err)
5350 case <-successCh:
5351 }
5352 }
5353
5354 func TestTransportDataAfter1xxHeader(t *testing.T) {
5355
5356 log.SetOutput(io.Discard)
5357 defer log.SetOutput(os.Stderr)
5358
5359
5360 tc := newTestClientConn(t)
5361 tc.greet()
5362
5363 req, _ := http.NewRequest("GET", "https://dummy.tld/", nil)
5364 rt := tc.roundTrip(req)
5365
5366 tc.wantFrameType(FrameHeaders)
5367 tc.writeHeaders(HeadersFrameParam{
5368 StreamID: rt.streamID(),
5369 EndHeaders: true,
5370 EndStream: false,
5371 BlockFragment: tc.makeHeaderBlockFragment(
5372 ":status", "100",
5373 ),
5374 })
5375 tc.writeData(rt.streamID(), true, []byte{0})
5376 err := rt.err()
5377 if err, ok := err.(StreamError); !ok || err.Code != ErrCodeProtocol {
5378 t.Errorf("RoundTrip error: %v; want ErrCodeProtocol", err)
5379 }
5380 tc.wantFrameType(FrameRSTStream)
5381 }
5382
5383 func TestIssue66763Race(t *testing.T) {
5384 tr := &Transport{
5385 IdleConnTimeout: 1 * time.Nanosecond,
5386 AllowHTTP: true,
5387 }
5388 defer tr.CloseIdleConnections()
5389
5390 cli, srv := net.Pipe()
5391 donec := make(chan struct{})
5392 go func() {
5393
5394
5395
5396 tr.NewClientConn(cli)
5397 close(donec)
5398 }()
5399
5400
5401
5402 io.ReadAll(srv)
5403 srv.Close()
5404
5405 <-donec
5406 }
5407
5408
5409
5410 func TestIssue67671(t *testing.T) {
5411 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {})
5412 tr := &Transport{
5413 TLSClientConfig: tlsConfigInsecure,
5414 AllowHTTP: true,
5415 }
5416 defer tr.CloseIdleConnections()
5417 req, _ := http.NewRequest("GET", ts.URL, nil)
5418 req.Close = true
5419 for i := 0; i < 2; i++ {
5420 res, err := tr.RoundTrip(req)
5421 if err != nil {
5422 t.Fatal(err)
5423 }
5424 res.Body.Close()
5425 }
5426 }
5427
// TestTransport1xxLimits checks when the Transport gives up on a stream that
// keeps receiving 1xx (here 103) informational responses. Each case sends
// hcount 103 HEADERS frames, each with a 1000-byte "x-field" value. It then
// expects either an RST_STREAM from the client (limited) or no client
// frames at all.
func TestTransport1xxLimits(t *testing.T) {
	for _, test := range []struct {
		name    string
		opt     any                                   // option passed through to newTestClientConn
		ctxfn   func(context.Context) context.Context // optional wrapper for the request context
		hcount  int                                   // number of 103 responses the server sends
		limited bool                                  // whether the client is expected to reset the stream
	}{{
		// No explicit limit: ten 103 responses are accepted.
		name:    "default",
		hcount:  10,
		limited: false,
	}, {
		// A 10000-byte MaxHeaderListSize on the HTTP/2 Transport is
		// expected to trip after ten ~1KB 1xx responses.
		name: "MaxHeaderListSize",
		opt: func(tr *Transport) {
			tr.MaxHeaderListSize = 10000
		},
		hcount:  10,
		limited: true,
	}, {
		// Same expectation with the limit set on the net/http Transport
		// instead (presumably newTestClientConn applies the option to it).
		name: "MaxResponseHeaderBytes",
		opt: func(tr *http.Transport) {
			tr.MaxResponseHeaderBytes = 10000
		},
		hcount:  10,
		limited: true,
	}, {
		// A Got1xxResponse hook returning an error on the 10th 1xx response
		// is expected to abort the stream.
		name: "limit by client trace",
		ctxfn: func(ctx context.Context) context.Context {
			count := 0
			return httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
				Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
					count++
					if count >= 10 {
						return errors.New("too many 1xx")
					}
					return nil
				},
			})
		},
		hcount:  10,
		limited: true,
	}, {
		// With a Got1xxResponse hook installed, twenty 1xx responses are
		// accepted despite MaxHeaderListSize = 10000.
		name: "limit disabled by client trace",
		opt: func(tr *Transport) {
			tr.MaxHeaderListSize = 10000
		},
		ctxfn: func(ctx context.Context) context.Context {
			return httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
				Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
					return nil
				},
			})
		},
		hcount:  20,
		limited: false,
	}} {
		t.Run(test.name, func(t *testing.T) {
			tc := newTestClientConn(t, test.opt)
			tc.greet()

			ctx := context.Background()
			if test.ctxfn != nil {
				ctx = test.ctxfn(ctx)
			}
			req, _ := http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil)
			rt := tc.roundTrip(req)
			tc.wantFrameType(FrameHeaders)

			for i := 0; i < test.hcount; i++ {
				// Before each 103, the client must not have sent anything:
				// an idle read reports os.ErrDeadlineExceeded.
				if fr, err := tc.fr.ReadFrame(); err != os.ErrDeadlineExceeded {
					t.Fatalf("after writing %v 1xx headers: read %v, %v; want idle", i, fr, err)
				}
				tc.writeHeaders(HeadersFrameParam{
					StreamID:   rt.streamID(),
					EndHeaders: true,
					EndStream:  false,
					BlockFragment: tc.makeHeaderBlockFragment(
						":status", "103",
						"x-field", strings.Repeat("a", 1000),
					),
				})
			}
			if test.limited {
				tc.wantFrameType(FrameRSTStream)
			} else {
				tc.wantIdle()
			}
		})
	}
}
5518
// TestTransportSendPingWithReset checks how stream resets interact with
// PINGs and a strict concurrency limit of maxConcurrent streams:
//   - the first reset sends a PING;
//   - further resets send no PING while it is outstanding, and a queued
//     request stays queued;
//   - once the PING is acknowledged, the queued request is sent;
//   - after DATA arrives on an open stream, resetting that stream sends a
//     new PING.
func TestTransportSendPingWithReset(t *testing.T) {
	tc := newTestClientConn(t, func(tr *Transport) {
		tr.StrictMaxConcurrentStreams = true
	})

	const maxConcurrent = 3
	tc.greet(Setting{SettingMaxConcurrentStreams, maxConcurrent})

	// Start maxConcurrent+1 requests. The first maxConcurrent are sent and
	// answered; the last one must not be sent (the client stays idle).
	var rts []*testRoundTrip
	for i := 0; i < maxConcurrent+1; i++ {
		req := must(http.NewRequest("GET", "https://dummy.tld/", nil))
		rt := tc.roundTrip(req)
		if i >= maxConcurrent {
			tc.wantIdle()
			continue
		}
		tc.wantFrameType(FrameHeaders)
		tc.writeHeaders(HeadersFrameParam{
			StreamID:   rt.streamID(),
			EndHeaders: true,
			BlockFragment: tc.makeHeaderBlockFragment(
				":status", "200",
			),
		})
		rt.wantStatus(200)
		rts = append(rts, rt)
	}

	// Closing the first response body resets its stream and sends a PING.
	// The queued request is still not sent.
	rts[0].response().Body.Close()
	tc.wantRSTStream(rts[0].streamID(), ErrCodeCancel)
	pf := readFrame[*PingFrame](t, tc)
	tc.wantIdle()

	// A second reset while the PING is outstanding sends no new PING, and
	// the queued request still is not sent.
	rts[1].response().Body.Close()
	tc.wantRSTStream(rts[1].streamID(), ErrCodeCancel)
	tc.wantIdle()

	// Acknowledging the PING lets the queued request go out. NOTE(review):
	// this suggests reset streams keep counting against the limit until the
	// PING is answered — confirm against transport.go.
	tc.writePing(true, pf.Data)
	tc.wantFrameType(FrameHeaders)
	tc.wantIdle()

	// The server sends DATA on the still-open third stream (presumably
	// showing the connection is alive again).
	tc.writeData(rts[2].streamID(), false, []byte{0})

	// Resetting that stream now sends a fresh PING.
	rts[2].response().Body.Close()
	tc.wantRSTStream(rts[2].streamID(), ErrCodeCancel)
	tc.wantFrameType(FramePing)
	tc.wantIdle()
}
5575
5576
5577
// TestTransportSendNoMoreThanOnePingWithReset checks that stream resets do
// not flood the connection with PINGs. While one PING is outstanding no
// other is sent. After it is acknowledged, a new PING is sent only once
// some other frame has been received from the server.
func TestTransportSendNoMoreThanOnePingWithReset(t *testing.T) {
	tc := newTestClientConn(t)
	tc.greet()

	// makeAndResetRequest starts a request, expects its HEADERS, cancels it,
	// and expects RST_STREAM(CANCEL). An unexpected frame left over from a
	// previous step (such as a PING) would presumably make this call's
	// HEADERS check fail.
	makeAndResetRequest := func() {
		t.Helper()
		ctx, cancel := context.WithCancel(context.Background())
		req := must(http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil))
		rt := tc.roundTrip(req)
		tc.wantFrameType(FrameHeaders)
		cancel()
		tc.wantRSTStream(rt.streamID(), ErrCodeCancel)
	}

	// The first reset (stream 1) sends a PING.
	makeAndResetRequest()
	pf1 := readFrame[*PingFrame](t, tc)

	// A second reset (stream 3) sends no PING: pf1 is still unacknowledged.
	makeAndResetRequest()

	// The server answers the already-reset stream 1.
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   1,
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
		),
	})

	// Even after receiving a frame, this reset sends no PING, because pf1
	// is still outstanding.
	makeAndResetRequest()

	// Acknowledge pf1.
	tc.writePing(true, pf1.Data)

	// No frame has arrived since the ack, so this reset sends no PING either.
	makeAndResetRequest()

	// The server answers the already-reset stream 3.
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   3,
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
		),
	})

	// A frame has now arrived since the last PING was acknowledged, so this
	// reset sends a new PING.
	makeAndResetRequest()
	tc.wantFrameType(FramePing)
}
5642
// TestTransportConnBecomesUnresponsive checks that a connection whose
// canceled streams are never confirmed by the server stops taking new
// requests. First one request on tc1 succeeds. Then maxConcurrent requests
// on tc1 are each sent and canceled; the first cancel also sends a PING the
// test never answers. After that, the next request must be sent on a new
// connection, tc2.
func TestTransportConnBecomesUnresponsive(t *testing.T) {
	// NOTE(review): the reason tc1 counts as full (reset streams presumably
	// keep occupying concurrency slots until the PING is acknowledged) is
	// inferred from the frame sequence below — confirm against transport.go.
	tt := newTestTransport(t)

	const maxConcurrent = 3

	t.Logf("first request opens a new connection and succeeds")
	req1 := must(http.NewRequest("GET", "https://dummy.tld/", nil))
	rt1 := tt.roundTrip(req1)
	tc1 := tt.getConn()
	// The client's SETTINGS and WINDOW_UPDATE precede the request HEADERS.
	tc1.wantFrameType(FrameSettings)
	tc1.wantFrameType(FrameWindowUpdate)
	hf1 := readFrame[*HeadersFrame](t, tc1)
	tc1.writeSettings(Setting{SettingMaxConcurrentStreams, maxConcurrent})
	// Presumably the client's SETTINGS ACK.
	tc1.wantFrameType(FrameSettings)
	tc1.writeHeaders(HeadersFrameParam{
		StreamID:   hf1.StreamID,
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc1.makeHeaderBlockFragment(
			":status", "200",
		),
	})
	rt1.wantStatus(200)
	rt1.response().Body.Close()

	// Use up tc1's stream limit with requests that are sent and then canceled
	// without any response. Each one must reuse tc1, not dial a new conn.
	for i := 0; i < maxConcurrent; i++ {
		t.Logf("request %v receives no response and is canceled", i)
		ctx, cancel := context.WithCancel(context.Background())
		req := must(http.NewRequestWithContext(ctx, "GET", "https://dummy.tld/", nil))
		tt.roundTrip(req)
		if tt.hasConn() {
			t.Fatalf("new connection created; expect existing conn to be reused")
		}
		tc1.wantFrameType(FrameHeaders)
		cancel()
		tc1.wantFrameType(FrameRSTStream)
		// Only the first reset sends a PING; the test never acknowledges it.
		if i == 0 {
			tc1.wantFrameType(FramePing)
		}
		tc1.wantIdle()
	}

	// tc1 is no longer usable for new streams, so this request must go out
	// on a freshly created connection, tc2, and succeed there.
	req2 := must(http.NewRequest("GET", "https://dummy.tld/", nil))
	rt2 := tt.roundTrip(req2)
	tc2 := tt.getConn()
	tc2.wantFrameType(FrameSettings)
	tc2.wantFrameType(FrameWindowUpdate)
	hf := readFrame[*HeadersFrame](t, tc2)
	tc2.writeSettings(Setting{SettingMaxConcurrentStreams, maxConcurrent})
	tc2.wantFrameType(FrameSettings)
	tc2.writeHeaders(HeadersFrameParam{
		StreamID:   hf.StreamID,
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc2.makeHeaderBlockFragment(
			":status", "200",
		),
	})
	rt2.wantStatus(200)
	rt2.response().Body.Close()
}
5712
5713
// TestTransportTLSNextProtoConnOK checks the following flow:
//   - ConfigureTransports wires the HTTP/2 Transport into the net/http
//     Transport's TLSNextProto["h2"] hook.
//   - A TLS connection handed to that hook is adopted by the HTTP/2
//     Transport.
//   - A later RoundTrip is sent on that connection (stream 1) and succeeds.
func TestTransportTLSNextProtoConnOK(t *testing.T) {
	t1 := &http.Transport{}
	t2, _ := ConfigureTransports(t1)
	tt := newTestTransport(t, t2)

	// Create a fake connection and pass it to TLSNextProto, as net/http
	// would after negotiating "h2". tt.group.Join presumably registers the
	// goroutine with the test's synchronization group.
	cli, _ := synctestNetPipe(tt.group)
	cliTLS := tls.Client(cli, tlsConfigInsecure)
	go func() {
		tt.group.Join()
		t1.TLSNextProto["h2"]("dummy.tld", cliTLS)
	}()
	tt.sync()
	tc := tt.getConn()
	tc.greet()

	// Send a request through the Transport; it must use the connection
	// provided above.
	req := must(http.NewRequest("GET", "https://dummy.tld/", nil))
	rt := tt.roundTrip(req)
	tc.wantHeaders(wantHeader{
		streamID:  1,
		endStream: true,
		header: http.Header{
			":authority": []string{"dummy.tld"},
			":method":    []string{"GET"},
			":path":      []string{"/"},
		},
	})
	tc.writeHeaders(HeadersFrameParam{
		StreamID:   1,
		EndHeaders: true,
		EndStream:  true,
		BlockFragment: tc.makeHeaderBlockFragment(
			":status", "200",
		),
	})
	rt.wantStatus(200)
	rt.wantBody(nil)
}
5754
5755
// TestTransportTLSNextProtoConnImmediateFailureUsed checks a connection
// handed to TLSNextProto["h2"] that breaks before any request uses it:
//   - the first RoundTrip is attempted on that connection and fails with a
//     real error, not ErrNoCachedConn;
//   - afterwards the connection is gone, so the next RoundTrip returns
//     ErrNoCachedConn.
func TestTransportTLSNextProtoConnImmediateFailureUsed(t *testing.T) {
	t1 := &http.Transport{}
	t2, _ := ConfigureTransports(t1)
	tt := newTestTransport(t, t2)

	// Create a fake connection and pass it to TLSNextProto.
	cli, _ := synctestNetPipe(tt.group)
	cliTLS := tls.Client(cli, tlsConfigInsecure)
	go func() {
		tt.group.Join()
		t1.TLSNextProto["h2"]("dummy.tld", cliTLS)
	}()
	tt.sync()
	tc := tt.getConn()

	// Break the connection immediately via tc.closeWrite, before greeting.
	tc.closeWrite()

	// The first request is sent on the broken connection and must fail with
	// an error other than ErrNoCachedConn.
	req := must(http.NewRequest("GET", "https://dummy.tld/", nil))
	rt := tt.roundTrip(req)
	if err := rt.err(); err == nil || errors.Is(err, ErrNoCachedConn) {
		t.Fatalf("RoundTrip with broken conn: got %v, want an error other than ErrNoCachedConn", err)
	}

	// The broken connection has now been used up, so a second request finds
	// no cached connection.
	rt = tt.roundTrip(req)
	if err := rt.err(); !errors.Is(err, ErrNoCachedConn) {
		t.Fatalf("RoundTrip after broken conn is used: got %v, want ErrNoCachedConn", err)
	}
}
5791
5792
5793
// TestTransportTLSNextProtoConnImmediateFailureUnused checks a connection
// handed to TLSNextProto["h2"] that breaks and then sits unused for 10s of
// test time. A later RoundTrip must find no cached connection
// (ErrNoCachedConn) instead of trying the dead one.
func TestTransportTLSNextProtoConnImmediateFailureUnused(t *testing.T) {
	t1 := &http.Transport{}
	t2, _ := ConfigureTransports(t1)
	tt := newTestTransport(t, t2)

	// Create a fake connection and pass it to TLSNextProto.
	cli, _ := synctestNetPipe(tt.group)
	cliTLS := tls.Client(cli, tlsConfigInsecure)
	go func() {
		tt.group.Join()
		t1.TLSNextProto["h2"]("dummy.tld", cliTLS)
	}()
	tt.sync()
	tc := tt.getConn()

	// Break the connection immediately via tc.closeWrite.
	tc.closeWrite()

	// Let time pass without using the connection, giving the Transport a
	// chance to notice the failure and drop it.
	tc.advance(10 * time.Second)

	// The dead connection must not be reused.
	req := must(http.NewRequest("GET", "https://dummy.tld/", nil))
	rt := tt.roundTrip(req)
	if err := rt.err(); !errors.Is(err, ErrNoCachedConn) {
		t.Fatalf("RoundTrip after broken conn expires: got %v, want ErrNoCachedConn", err)
	}
}
5825
5826 func TestExtendedConnectClientWithServerSupport(t *testing.T) {
5827 disableExtendedConnectProtocol = false
5828 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
5829 if r.Header.Get(":protocol") != "extended-connect" {
5830 t.Fatalf("unexpected :protocol header received")
5831 }
5832 t.Log(io.Copy(w, r.Body))
5833 })
5834 tr := &Transport{
5835 TLSClientConfig: tlsConfigInsecure,
5836 AllowHTTP: true,
5837 }
5838 defer tr.CloseIdleConnections()
5839 pr, pw := io.Pipe()
5840 pwDone := make(chan struct{})
5841 req, _ := http.NewRequest("CONNECT", ts.URL, pr)
5842 req.Header.Set(":protocol", "extended-connect")
5843 go func() {
5844 pw.Write([]byte("hello, extended connect"))
5845 pw.Close()
5846 close(pwDone)
5847 }()
5848
5849 res, err := tr.RoundTrip(req)
5850 if err != nil {
5851 t.Fatal(err)
5852 }
5853 body, err := io.ReadAll(res.Body)
5854 if err != nil {
5855 t.Fatal(err)
5856 }
5857 if !bytes.Equal(body, []byte("hello, extended connect")) {
5858 t.Fatal("unexpected body received")
5859 }
5860 }
5861
5862 func TestExtendedConnectClientWithoutServerSupport(t *testing.T) {
5863 disableExtendedConnectProtocol = true
5864 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
5865 io.Copy(w, r.Body)
5866 })
5867 tr := &Transport{
5868 TLSClientConfig: tlsConfigInsecure,
5869 AllowHTTP: true,
5870 }
5871 defer tr.CloseIdleConnections()
5872 pr, pw := io.Pipe()
5873 pwDone := make(chan struct{})
5874 req, _ := http.NewRequest("CONNECT", ts.URL, pr)
5875 req.Header.Set(":protocol", "extended-connect")
5876 go func() {
5877 pw.Write([]byte("hello, extended connect"))
5878 pw.Close()
5879 close(pwDone)
5880 }()
5881
5882 _, err := tr.RoundTrip(req)
5883 if !errors.Is(err, errExtendedConnectNotSupported) {
5884 t.Fatalf("expected error errExtendedConnectNotSupported, got: %v", err)
5885 }
5886 }
5887
View as plain text