1
2
3
4
5 package http2
6
7 import (
8 "bytes"
9 "compress/gzip"
10 "compress/zlib"
11 "context"
12 "crypto/tls"
13 "errors"
14 "flag"
15 "fmt"
16 "io"
17 "log"
18 "math"
19 "net"
20 "net/http"
21 "net/http/httptest"
22 "os"
23 "reflect"
24 "runtime"
25 "strconv"
26 "strings"
27 "sync"
28 "testing"
29 "time"
30
31 "golang.org/x/net/http2/hpack"
32 )
33
// stderrVerbose is set by -stderr_verbose to mirror verbose test output to
// stderr without buffering.
var stderrVerbose = flag.Bool("stderr_verbose", false, "Mirror verbosity to stderr, unbuffered")

// stderrv returns os.Stderr when -stderr_verbose is set and io.Discard
// otherwise, so it can always be included in an io.MultiWriter.
func stderrv() io.Writer {
	if !*stderrVerbose {
		return io.Discard
	}
	return os.Stderr
}
43
// safeBuffer is a bytes.Buffer whose methods are serialized by a mutex, so it
// can be written from one goroutine (for example as a log sink) while
// another goroutine inspects it.
type safeBuffer struct {
	b bytes.Buffer
	m sync.Mutex
}

// Write appends d to the buffer while holding the lock.
func (sb *safeBuffer) Write(d []byte) (int, error) {
	sb.m.Lock()
	defer sb.m.Unlock()
	return sb.b.Write(d)
}

// Bytes returns a copy of the buffered contents.
//
// bytes.Buffer.Bytes returns a slice that is only valid until the next
// buffer modification. Once the lock is released, a concurrent Write may
// modify the buffer at any time. Copying under the lock gives the caller a
// stable snapshot that does not alias the buffer.
func (sb *safeBuffer) Bytes() []byte {
	sb.m.Lock()
	defer sb.m.Unlock()
	return append([]byte(nil), sb.b.Bytes()...)
}

// Len returns the number of buffered bytes.
func (sb *safeBuffer) Len() int {
	sb.m.Lock()
	defer sb.m.Unlock()
	return sb.b.Len()
}
66
// serverTester plays the client side of a single HTTP/2 server connection in
// tests. It is built by newServerTester (in-memory synctest pipe, synthetic
// time) or by newServerTesterWithRealConn (real TLS connection to an
// httptest.Server).
type serverTester struct {
	cc           net.Conn // client side of the connection to the server
	t            testing.TB
	group        *synctestGroup // nil for newServerTesterWithRealConn testers
	h1server     *http.Server
	h2server     *Server
	serverLogBuf safeBuffer // copy of server ErrorLog output when the constructor installs the default logger
	logFilter    []string   // phrases registered via addLogFilter
	scMu         sync.Mutex // held while sc is set by the server hook and in closeConn
	sc           *serverConn
	testConnFramer // provides fr (client Framer) and dec (HPACK decoder)

	// Framer debug logs captured by newServerTesterWithRealConn and written
	// to the test log by Close when the test has failed. Reads and writes are
	// logged under separate mutexes into separate buffers.
	frameReadLogMu   sync.Mutex
	frameReadLogBuf  bytes.Buffer
	frameWriteLogMu  sync.Mutex
	frameWriteLogBuf bytes.Buffer

	// Scratch state for encoding request header blocks; hpackEnc writes
	// into headerBuf.
	headerBuf bytes.Buffer
	hpackEnc  *hpack.Encoder
}
92
93 func init() {
94 testHookOnPanicMu = new(sync.Mutex)
95 goAwayTimeout = 25 * time.Millisecond
96 }
97
98 func resetHooks() {
99 testHookOnPanicMu.Lock()
100 testHookOnPanic = nil
101 testHookOnPanicMu.Unlock()
102 }
103
104 func newTestServer(t testing.TB, handler http.HandlerFunc, opts ...interface{}) *httptest.Server {
105 ts := httptest.NewUnstartedServer(handler)
106 ts.EnableHTTP2 = true
107 ts.Config.ErrorLog = log.New(twriter{t: t}, "", log.LstdFlags)
108 h2server := new(Server)
109 for _, opt := range opts {
110 switch v := opt.(type) {
111 case func(*httptest.Server):
112 v(ts)
113 case func(*http.Server):
114 v(ts.Config)
115 case func(*Server):
116 v(h2server)
117 default:
118 t.Fatalf("unknown newTestServer option type %T", v)
119 }
120 }
121 ConfigureServer(ts.Config, h2server)
122
123
124
125 ts.TLS = ts.Config.TLSConfig
126
127
128
129
130
131 ts.Config.TLSConfig.MinVersion = tls.VersionTLS10
132
133 ts.StartTLS()
134 t.Cleanup(func() {
135 ts.CloseClientConnections()
136 ts.Close()
137 })
138
139 return ts
140 }
141
// serverTesterOpt is a string-tagged option understood by
// newServerTesterWithRealConn.
type serverTesterOpt string

var (
	// optFramerReuseFrames asks newServerTesterWithRealConn to call
	// SetReuseFrames on the tester's Framer.
	optFramerReuseFrames = serverTesterOpt("frame_reuse_frames")

	// optQuiet replaces the http.Server's error log with one that discards
	// all output.
	optQuiet = func(server *http.Server) {
		server.ErrorLog = log.New(io.Discard, "", 0)
	}
)
149
150 func newServerTester(t testing.TB, handler http.HandlerFunc, opts ...interface{}) *serverTester {
151 t.Helper()
152 g := newSynctest(time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC))
153 t.Cleanup(func() {
154 g.Close(t)
155 })
156
157 h1server := &http.Server{}
158 h2server := &Server{
159 group: g,
160 }
161 tlsState := tls.ConnectionState{
162 Version: tls.VersionTLS13,
163 ServerName: "go.dev",
164 CipherSuite: tls.TLS_AES_128_GCM_SHA256,
165 }
166 for _, opt := range opts {
167 switch v := opt.(type) {
168 case func(*Server):
169 v(h2server)
170 case func(*http.Server):
171 v(h1server)
172 case func(*tls.ConnectionState):
173 v(&tlsState)
174 default:
175 t.Fatalf("unknown newServerTester option type %T", v)
176 }
177 }
178 ConfigureServer(h1server, h2server)
179
180 cli, srv := synctestNetPipe(g)
181 cli.SetReadDeadline(g.Now())
182 cli.autoWait = true
183
184 st := &serverTester{
185 t: t,
186 cc: cli,
187 group: g,
188 h1server: h1server,
189 h2server: h2server,
190 }
191 st.hpackEnc = hpack.NewEncoder(&st.headerBuf)
192 if h1server.ErrorLog == nil {
193 h1server.ErrorLog = log.New(io.MultiWriter(stderrv(), twriter{t: t, st: st}, &st.serverLogBuf), "", log.LstdFlags)
194 }
195
196 t.Cleanup(func() {
197 st.Close()
198 g.AdvanceTime(goAwayTimeout)
199 })
200
201 connc := make(chan *serverConn)
202 go func() {
203 g.Join()
204 h2server.serveConn(&netConnWithConnectionState{
205 Conn: srv,
206 state: tlsState,
207 }, &ServeConnOpts{
208 Handler: handler,
209 BaseConfig: h1server,
210 }, func(sc *serverConn) {
211 connc <- sc
212 })
213 }()
214 st.sc = <-connc
215
216 st.fr = NewFramer(st.cc, st.cc)
217 st.testConnFramer = testConnFramer{
218 t: t,
219 fr: NewFramer(st.cc, st.cc),
220 dec: hpack.NewDecoder(initialHeaderTableSize, nil),
221 }
222 g.Wait()
223 return st
224 }
225
// netConnWithConnectionState wraps a net.Conn and also reports a fixed
// tls.ConnectionState. This lets a plain connection (such as an in-memory
// pipe) present TLS connection state.
type netConnWithConnectionState struct {
	net.Conn
	state tls.ConnectionState
}

// ConnectionState returns the state the wrapper was constructed with.
func (nc *netConnWithConnectionState) ConnectionState() tls.ConnectionState {
	s := nc.state
	return s
}
234
235
236
237
238
// newServerTesterWithRealConn creates a serverTester connected over a real
// TLS connection (tls.Dial over TCP) to an httptest.Server. Unlike
// newServerTester it uses no synctest group: st.group stays nil and time is
// real.
//
// Accepted opts are:
//   - func(*tls.Config): edits the client TLS config
//   - func(*httptest.Server)
//   - func(*http.Server): applied to ts.Config
//   - func(*Server)
//   - serverTesterOpt: currently only optFramerReuseFrames
//   - func(net.Conn, http.ConnState): installed as ts.Config.ConnState
//
// Any other type fails the test.
//
// Global side effects: it calls resetHooks, replaces testHookGetServerConn so
// the server's *serverConn is captured into st.sc, and redirects the
// standard logger's output (Close restores it to os.Stderr). Only the
// httptest.Server is closed through t.Cleanup; the caller is expected to
// call st.Close.
func newServerTesterWithRealConn(t testing.TB, handler http.HandlerFunc, opts ...interface{}) *serverTester {
	resetHooks()

	ts := httptest.NewUnstartedServer(handler)
	t.Cleanup(ts.Close)

	// Client-side TLS config for tls.Dial below: server certificate
	// verification is skipped and only the HTTP/2 ALPN protocol is offered.
	tlsConfig := &tls.Config{
		InsecureSkipVerify: true,
		NextProtos:         []string{NextProtoTLS},
	}

	var framerReuseFrames bool
	h2server := new(Server)
	for _, opt := range opts {
		switch v := opt.(type) {
		case func(*tls.Config):
			v(tlsConfig)
		case func(*httptest.Server):
			v(ts)
		case func(*http.Server):
			v(ts.Config)
		case func(*Server):
			v(h2server)
		case serverTesterOpt:
			switch v {
			case optFramerReuseFrames:
				framerReuseFrames = true
			}
		case func(net.Conn, http.ConnState):
			ts.Config.ConnState = v
		default:
			t.Fatalf("unknown newServerTester option type %T", v)
		}
	}

	ConfigureServer(ts.Config, h2server)

	// Lower the minimum to TLS 1.0, below Go's default minimum, so tests can
	// exercise rejection of old TLS versions.
	ts.Config.TLSConfig.MinVersion = tls.VersionTLS10

	st := &serverTester{
		t: t,
	}
	st.hpackEnc = hpack.NewEncoder(&st.headerBuf)

	// Give the httptest.Server the TLS config that ConfigureServer populated.
	ts.TLS = ts.Config.TLSConfig
	if ts.Config.ErrorLog == nil {
		ts.Config.ErrorLog = log.New(io.MultiWriter(stderrv(), twriter{t: t, st: st}, &st.serverLogBuf), "", log.LstdFlags)
	}
	ts.StartTLS()

	if VerboseLogs {
		t.Logf("Running test server at: %s", ts.URL)
	}
	// Record the server-side *serverConn in st.sc under scMu.
	// NOTE(review): this package-level hook is not cleared here.
	testHookGetServerConn = func(v *serverConn) {
		st.scMu.Lock()
		defer st.scMu.Unlock()
		st.sc = v
	}
	// Redirect the standard logger. Close restores it to os.Stderr.
	log.SetOutput(io.MultiWriter(stderrv(), twriter{t: t, st: st}))
	cc, err := tls.Dial("tcp", ts.Listener.Addr().String(), tlsConfig)
	if err != nil {
		t.Fatal(err)
	}
	st.cc = cc
	st.testConnFramer = testConnFramer{
		t:   t,
		fr:  NewFramer(st.cc, st.cc),
		dec: hpack.NewDecoder(initialHeaderTableSize, nil),
	}
	if framerReuseFrames {
		st.fr.SetReuseFrames()
	}
	// Unless logFrameReads or logFrameWrites is set, turn on the framer's
	// read/write debug logging. Each entry is timestamped and captured in a
	// per-direction buffer that Close dumps if the test failed.
	if !logFrameReads && !logFrameWrites {
		st.fr.debugReadLoggerf = func(m string, v ...interface{}) {
			m = time.Now().Format("2006-01-02 15:04:05.999999999 ") + strings.TrimPrefix(m, "http2: ") + "\n"
			st.frameReadLogMu.Lock()
			fmt.Fprintf(&st.frameReadLogBuf, m, v...)
			st.frameReadLogMu.Unlock()
		}
		st.fr.debugWriteLoggerf = func(m string, v ...interface{}) {
			m = time.Now().Format("2006-01-02 15:04:05.999999999 ") + strings.TrimPrefix(m, "http2: ") + "\n"
			st.frameWriteLogMu.Lock()
			fmt.Fprintf(&st.frameWriteLogBuf, m, v...)
			st.frameWriteLogMu.Unlock()
		}
		st.fr.logReads = true
		st.fr.logWrites = true
	}
	return st
}
333
334
335 func (st *serverTester) sync() {
336 if st.group != nil {
337 st.group.Wait()
338 }
339 }
340
341
342 func (st *serverTester) advance(d time.Duration) {
343 st.group.AdvanceTime(d)
344 }
345
346 func (st *serverTester) authority() string {
347 return "dummy.tld"
348 }
349
350 func (st *serverTester) closeConn() {
351 st.scMu.Lock()
352 defer st.scMu.Unlock()
353 st.sc.conn.Close()
354 }
355
356 func (st *serverTester) addLogFilter(phrase string) {
357 st.logFilter = append(st.logFilter, phrase)
358 }
359
360 func (st *serverTester) stream(id uint32) *stream {
361 ch := make(chan *stream, 1)
362 st.sc.serveMsgCh <- func(int) {
363 ch <- st.sc.streams[id]
364 }
365 return <-ch
366 }
367
368 func (st *serverTester) streamState(id uint32) streamState {
369 ch := make(chan streamState, 1)
370 st.sc.serveMsgCh <- func(int) {
371 state, _ := st.sc.state(id)
372 ch <- state
373 }
374 return <-ch
375 }
376
377
378 func (st *serverTester) loopNum() int {
379 lastc := make(chan int, 1)
380 st.sc.serveMsgCh <- func(loopNum int) {
381 lastc <- loopNum
382 }
383 return <-lastc
384 }
385
386
387
388
389 func (st *serverTester) awaitIdle() {
390 remain := 50
391 last := st.loopNum()
392 for remain > 0 {
393 n := st.loopNum()
394 if n == last+1 {
395 remain--
396 } else {
397 remain = 50
398 }
399 last = n
400 }
401 }
402
403 func (st *serverTester) Close() {
404 if st.t.Failed() {
405 st.frameReadLogMu.Lock()
406 if st.frameReadLogBuf.Len() > 0 {
407 st.t.Logf("Framer read log:\n%s", st.frameReadLogBuf.String())
408 }
409 st.frameReadLogMu.Unlock()
410
411 st.frameWriteLogMu.Lock()
412 if st.frameWriteLogBuf.Len() > 0 {
413 st.t.Logf("Framer write log:\n%s", st.frameWriteLogBuf.String())
414 }
415 st.frameWriteLogMu.Unlock()
416
417
418
419
420
421 if st.cc != nil {
422 st.cc.Close()
423 }
424 }
425 if st.cc != nil {
426 st.cc.Close()
427 }
428 log.SetOutput(os.Stderr)
429 }
430
431
432
433 func (st *serverTester) greet() {
434 st.t.Helper()
435 st.greetAndCheckSettings(func(Setting) error { return nil })
436 }
437
438 func (st *serverTester) greetAndCheckSettings(checkSetting func(s Setting) error) {
439 st.t.Helper()
440 st.writePreface()
441 st.writeSettings()
442 st.sync()
443 readFrame[*SettingsFrame](st.t, st).ForeachSetting(checkSetting)
444 st.writeSettingsAck()
445
446
447 var gotSettingsAck bool
448 var gotWindowUpdate bool
449
450 for i := 0; i < 2; i++ {
451 f := st.readFrame()
452 if f == nil {
453 st.t.Fatal("wanted a settings ACK and window update, got none")
454 }
455 switch f := f.(type) {
456 case *SettingsFrame:
457 if !f.Header().Flags.Has(FlagSettingsAck) {
458 st.t.Fatal("Settings Frame didn't have ACK set")
459 }
460 gotSettingsAck = true
461
462 case *WindowUpdateFrame:
463 if f.FrameHeader.StreamID != 0 {
464 st.t.Fatalf("WindowUpdate StreamID = %d; want 0", f.FrameHeader.StreamID)
465 }
466 conf := configFromServer(st.sc.hs, st.sc.srv)
467 incr := uint32(conf.MaxUploadBufferPerConnection - initialWindowSize)
468 if f.Increment != incr {
469 st.t.Fatalf("WindowUpdate increment = %d; want %d", f.Increment, incr)
470 }
471 gotWindowUpdate = true
472
473 default:
474 st.t.Fatalf("Wanting a settings ACK or window update, received a %T", f)
475 }
476 }
477
478 if !gotSettingsAck {
479 st.t.Fatalf("Didn't get a settings ACK")
480 }
481 if !gotWindowUpdate {
482 st.t.Fatalf("Didn't get a window update")
483 }
484 }
485
486 func (st *serverTester) writePreface() {
487 n, err := st.cc.Write(clientPreface)
488 if err != nil {
489 st.t.Fatalf("Error writing client preface: %v", err)
490 }
491 if n != len(clientPreface) {
492 st.t.Fatalf("Writing client preface, wrote %d bytes; want %d", n, len(clientPreface))
493 }
494 }
495
496 func (st *serverTester) encodeHeaderField(k, v string) {
497 err := st.hpackEnc.WriteField(hpack.HeaderField{Name: k, Value: v})
498 if err != nil {
499 st.t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err)
500 }
501 }
502
503
504
505 func (st *serverTester) encodeHeaderRaw(headers ...string) []byte {
506 if len(headers)%2 == 1 {
507 panic("odd number of kv args")
508 }
509 st.headerBuf.Reset()
510 for len(headers) > 0 {
511 k, v := headers[0], headers[1]
512 st.encodeHeaderField(k, v)
513 headers = headers[2:]
514 }
515 return st.headerBuf.Bytes()
516 }
517
518
519
520
521
522
523 func (st *serverTester) encodeHeader(headers ...string) []byte {
524 if len(headers)%2 == 1 {
525 panic("odd number of kv args")
526 }
527
528 st.headerBuf.Reset()
529 defaultAuthority := st.authority()
530
531 if len(headers) == 0 {
532
533
534 st.encodeHeaderField(":method", "GET")
535 st.encodeHeaderField(":scheme", "https")
536 st.encodeHeaderField(":authority", defaultAuthority)
537 st.encodeHeaderField(":path", "/")
538 return st.headerBuf.Bytes()
539 }
540
541 if len(headers) == 2 && headers[0] == ":method" {
542
543 st.encodeHeaderField(":method", headers[1])
544 st.encodeHeaderField(":scheme", "https")
545 st.encodeHeaderField(":authority", defaultAuthority)
546 st.encodeHeaderField(":path", "/")
547 return st.headerBuf.Bytes()
548 }
549
550 pseudoCount := map[string]int{}
551 keys := []string{":method", ":scheme", ":authority", ":path"}
552 vals := map[string][]string{
553 ":method": {"GET"},
554 ":scheme": {"https"},
555 ":authority": {defaultAuthority},
556 ":path": {"/"},
557 }
558 for len(headers) > 0 {
559 k, v := headers[0], headers[1]
560 headers = headers[2:]
561 if _, ok := vals[k]; !ok {
562 keys = append(keys, k)
563 }
564 if strings.HasPrefix(k, ":") {
565 pseudoCount[k]++
566 if pseudoCount[k] == 1 {
567 vals[k] = []string{v}
568 } else {
569
570 vals[k] = append(vals[k], v)
571 }
572 } else {
573 vals[k] = append(vals[k], v)
574 }
575 }
576 for _, k := range keys {
577 for _, v := range vals[k] {
578 st.encodeHeaderField(k, v)
579 }
580 }
581 return st.headerBuf.Bytes()
582 }
583
584
585 func (st *serverTester) bodylessReq1(headers ...string) {
586 st.writeHeaders(HeadersFrameParam{
587 StreamID: 1,
588 BlockFragment: st.encodeHeader(headers...),
589 EndStream: true,
590 EndHeaders: true,
591 })
592 }
593
594 func (st *serverTester) wantFlowControlConsumed(streamID, consumed int32) {
595 conf := configFromServer(st.sc.hs, st.sc.srv)
596 var initial int32
597 if streamID == 0 {
598 initial = conf.MaxUploadBufferPerConnection
599 } else {
600 initial = conf.MaxUploadBufferPerStream
601 }
602 donec := make(chan struct{})
603 st.sc.sendServeMsg(func(sc *serverConn) {
604 defer close(donec)
605 var avail int32
606 if streamID == 0 {
607 avail = sc.inflow.avail + sc.inflow.unsent
608 } else {
609 }
610 if got, want := initial-avail, consumed; got != want {
611 st.t.Errorf("stream %v flow control consumed: %v, want %v", streamID, got, want)
612 }
613 })
614 <-donec
615 }
616
617 func TestServer(t *testing.T) {
618 gotReq := make(chan bool, 1)
619 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
620 w.Header().Set("Foo", "Bar")
621 gotReq <- true
622 })
623 defer st.Close()
624
625 st.greet()
626 st.writeHeaders(HeadersFrameParam{
627 StreamID: 1,
628 BlockFragment: st.encodeHeader(),
629 EndStream: true,
630 EndHeaders: true,
631 })
632
633 <-gotReq
634 }
635
636 func TestServer_Request_Get(t *testing.T) {
637 testServerRequest(t, func(st *serverTester) {
638 st.writeHeaders(HeadersFrameParam{
639 StreamID: 1,
640 BlockFragment: st.encodeHeader("foo-bar", "some-value"),
641 EndStream: true,
642 EndHeaders: true,
643 })
644 }, func(r *http.Request) {
645 if r.Method != "GET" {
646 t.Errorf("Method = %q; want GET", r.Method)
647 }
648 if r.URL.Path != "/" {
649 t.Errorf("URL.Path = %q; want /", r.URL.Path)
650 }
651 if r.ContentLength != 0 {
652 t.Errorf("ContentLength = %v; want 0", r.ContentLength)
653 }
654 if r.Close {
655 t.Error("Close = true; want false")
656 }
657 if !strings.Contains(r.RemoteAddr, ":") {
658 t.Errorf("RemoteAddr = %q; want something with a colon", r.RemoteAddr)
659 }
660 if r.Proto != "HTTP/2.0" || r.ProtoMajor != 2 || r.ProtoMinor != 0 {
661 t.Errorf("Proto = %q Major=%v,Minor=%v; want HTTP/2.0", r.Proto, r.ProtoMajor, r.ProtoMinor)
662 }
663 wantHeader := http.Header{
664 "Foo-Bar": []string{"some-value"},
665 }
666 if !reflect.DeepEqual(r.Header, wantHeader) {
667 t.Errorf("Header = %#v; want %#v", r.Header, wantHeader)
668 }
669 if n, err := r.Body.Read([]byte(" ")); err != io.EOF || n != 0 {
670 t.Errorf("Read = %d, %v; want 0, EOF", n, err)
671 }
672 })
673 }
674
675 func TestServer_Request_Get_PathSlashes(t *testing.T) {
676 testServerRequest(t, func(st *serverTester) {
677 st.writeHeaders(HeadersFrameParam{
678 StreamID: 1,
679 BlockFragment: st.encodeHeader(":path", "/%2f/"),
680 EndStream: true,
681 EndHeaders: true,
682 })
683 }, func(r *http.Request) {
684 if r.RequestURI != "/%2f/" {
685 t.Errorf("RequestURI = %q; want /%%2f/", r.RequestURI)
686 }
687 if r.URL.Path != "///" {
688 t.Errorf("URL.Path = %q; want ///", r.URL.Path)
689 }
690 })
691 }
692
693
694
695
696
697 func TestServer_Request_Post_NoContentLength_EndStream(t *testing.T) {
698 testServerRequest(t, func(st *serverTester) {
699 st.writeHeaders(HeadersFrameParam{
700 StreamID: 1,
701 BlockFragment: st.encodeHeader(":method", "POST"),
702 EndStream: true,
703 EndHeaders: true,
704 })
705 }, func(r *http.Request) {
706 if r.Method != "POST" {
707 t.Errorf("Method = %q; want POST", r.Method)
708 }
709 if r.ContentLength != 0 {
710 t.Errorf("ContentLength = %v; want 0", r.ContentLength)
711 }
712 if n, err := r.Body.Read([]byte(" ")); err != io.EOF || n != 0 {
713 t.Errorf("Read = %d, %v; want 0, EOF", n, err)
714 }
715 })
716 }
717
718 func TestServer_Request_Post_Body_ImmediateEOF(t *testing.T) {
719 testBodyContents(t, -1, "", func(st *serverTester) {
720 st.writeHeaders(HeadersFrameParam{
721 StreamID: 1,
722 BlockFragment: st.encodeHeader(":method", "POST"),
723 EndStream: false,
724 EndHeaders: true,
725 })
726 st.writeData(1, true, nil)
727 })
728 }
729
730 func TestServer_Request_Post_Body_OneData(t *testing.T) {
731 const content = "Some content"
732 testBodyContents(t, -1, content, func(st *serverTester) {
733 st.writeHeaders(HeadersFrameParam{
734 StreamID: 1,
735 BlockFragment: st.encodeHeader(":method", "POST"),
736 EndStream: false,
737 EndHeaders: true,
738 })
739 st.writeData(1, true, []byte(content))
740 })
741 }
742
743 func TestServer_Request_Post_Body_TwoData(t *testing.T) {
744 const content = "Some content"
745 testBodyContents(t, -1, content, func(st *serverTester) {
746 st.writeHeaders(HeadersFrameParam{
747 StreamID: 1,
748 BlockFragment: st.encodeHeader(":method", "POST"),
749 EndStream: false,
750 EndHeaders: true,
751 })
752 st.writeData(1, false, []byte(content[:5]))
753 st.writeData(1, true, []byte(content[5:]))
754 })
755 }
756
757 func TestServer_Request_Post_Body_ContentLength_Correct(t *testing.T) {
758 const content = "Some content"
759 testBodyContents(t, int64(len(content)), content, func(st *serverTester) {
760 st.writeHeaders(HeadersFrameParam{
761 StreamID: 1,
762 BlockFragment: st.encodeHeader(
763 ":method", "POST",
764 "content-length", strconv.Itoa(len(content)),
765 ),
766 EndStream: false,
767 EndHeaders: true,
768 })
769 st.writeData(1, true, []byte(content))
770 })
771 }
772
773 func TestServer_Request_Post_Body_ContentLength_TooLarge(t *testing.T) {
774 testBodyContentsFail(t, 3, "request declared a Content-Length of 3 but only wrote 2 bytes",
775 func(st *serverTester) {
776 st.writeHeaders(HeadersFrameParam{
777 StreamID: 1,
778 BlockFragment: st.encodeHeader(
779 ":method", "POST",
780 "content-length", "3",
781 ),
782 EndStream: false,
783 EndHeaders: true,
784 })
785 st.writeData(1, true, []byte("12"))
786 })
787 }
788
789 func TestServer_Request_Post_Body_ContentLength_TooSmall(t *testing.T) {
790 testBodyContentsFail(t, 4, "sender tried to send more than declared Content-Length of 4 bytes",
791 func(st *serverTester) {
792 st.writeHeaders(HeadersFrameParam{
793 StreamID: 1,
794 BlockFragment: st.encodeHeader(
795 ":method", "POST",
796 "content-length", "4",
797 ),
798 EndStream: false,
799 EndHeaders: true,
800 })
801 st.writeData(1, true, []byte("12345"))
802
803
804 st.wantRSTStream(1, ErrCodeProtocol)
805 st.wantFlowControlConsumed(0, 0)
806 })
807 }
808
809 func testBodyContents(t *testing.T, wantContentLength int64, wantBody string, write func(st *serverTester)) {
810 testServerRequest(t, write, func(r *http.Request) {
811 if r.Method != "POST" {
812 t.Errorf("Method = %q; want POST", r.Method)
813 }
814 if r.ContentLength != wantContentLength {
815 t.Errorf("ContentLength = %v; want %d", r.ContentLength, wantContentLength)
816 }
817 all, err := io.ReadAll(r.Body)
818 if err != nil {
819 t.Fatal(err)
820 }
821 if string(all) != wantBody {
822 t.Errorf("Read = %q; want %q", all, wantBody)
823 }
824 if err := r.Body.Close(); err != nil {
825 t.Fatalf("Close: %v", err)
826 }
827 })
828 }
829
830 func testBodyContentsFail(t *testing.T, wantContentLength int64, wantReadError string, write func(st *serverTester)) {
831 testServerRequest(t, write, func(r *http.Request) {
832 if r.Method != "POST" {
833 t.Errorf("Method = %q; want POST", r.Method)
834 }
835 if r.ContentLength != wantContentLength {
836 t.Errorf("ContentLength = %v; want %d", r.ContentLength, wantContentLength)
837 }
838 all, err := io.ReadAll(r.Body)
839 if err == nil {
840 t.Fatalf("expected an error (%q) reading from the body. Successfully read %q instead.",
841 wantReadError, all)
842 }
843 if !strings.Contains(err.Error(), wantReadError) {
844 t.Fatalf("Body.Read = %v; want substring %q", err, wantReadError)
845 }
846 if err := r.Body.Close(); err != nil {
847 t.Fatalf("Close: %v", err)
848 }
849 })
850 }
851
852
853 func TestServer_Request_Get_Host(t *testing.T) {
854 const host = "example.com"
855 testServerRequest(t, func(st *serverTester) {
856 st.writeHeaders(HeadersFrameParam{
857 StreamID: 1,
858 BlockFragment: st.encodeHeader(":authority", "", "host", host),
859 EndStream: true,
860 EndHeaders: true,
861 })
862 }, func(r *http.Request) {
863 if r.Host != host {
864 t.Errorf("Host = %q; want %q", r.Host, host)
865 }
866 })
867 }
868
869
870 func TestServer_Request_Get_Authority(t *testing.T) {
871 const host = "example.com"
872 testServerRequest(t, func(st *serverTester) {
873 st.writeHeaders(HeadersFrameParam{
874 StreamID: 1,
875 BlockFragment: st.encodeHeader(":authority", host),
876 EndStream: true,
877 EndHeaders: true,
878 })
879 }, func(r *http.Request) {
880 if r.Host != host {
881 t.Errorf("Host = %q; want %q", r.Host, host)
882 }
883 })
884 }
885
886 func TestServer_Request_WithContinuation(t *testing.T) {
887 wantHeader := http.Header{
888 "Foo-One": []string{"value-one"},
889 "Foo-Two": []string{"value-two"},
890 "Foo-Three": []string{"value-three"},
891 }
892 testServerRequest(t, func(st *serverTester) {
893 fullHeaders := st.encodeHeader(
894 "foo-one", "value-one",
895 "foo-two", "value-two",
896 "foo-three", "value-three",
897 )
898 remain := fullHeaders
899 chunks := 0
900 for len(remain) > 0 {
901 const maxChunkSize = 5
902 chunk := remain
903 if len(chunk) > maxChunkSize {
904 chunk = chunk[:maxChunkSize]
905 }
906 remain = remain[len(chunk):]
907
908 if chunks == 0 {
909 st.writeHeaders(HeadersFrameParam{
910 StreamID: 1,
911 BlockFragment: chunk,
912 EndStream: true,
913 EndHeaders: false,
914 })
915 } else {
916 err := st.fr.WriteContinuation(1, len(remain) == 0, chunk)
917 if err != nil {
918 t.Fatal(err)
919 }
920 }
921 chunks++
922 }
923 if chunks < 2 {
924 t.Fatal("too few chunks")
925 }
926 }, func(r *http.Request) {
927 if !reflect.DeepEqual(r.Header, wantHeader) {
928 t.Errorf("Header = %#v; want %#v", r.Header, wantHeader)
929 }
930 })
931 }
932
933
934 func TestServer_Request_CookieConcat(t *testing.T) {
935 const host = "example.com"
936 testServerRequest(t, func(st *serverTester) {
937 st.bodylessReq1(
938 ":authority", host,
939 "cookie", "a=b",
940 "cookie", "c=d",
941 "cookie", "e=f",
942 )
943 }, func(r *http.Request) {
944 const want = "a=b; c=d; e=f"
945 if got := r.Header.Get("Cookie"); got != want {
946 t.Errorf("Cookie = %q; want %q", got, want)
947 }
948 })
949 }
950
951 func TestServer_Request_Reject_CapitalHeader(t *testing.T) {
952 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("UPPER", "v") })
953 }
954
955 func TestServer_Request_Reject_HeaderFieldNameColon(t *testing.T) {
956 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("has:colon", "v") })
957 }
958
959 func TestServer_Request_Reject_HeaderFieldNameNULL(t *testing.T) {
960 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("has\x00null", "v") })
961 }
962
963 func TestServer_Request_Reject_HeaderFieldNameEmpty(t *testing.T) {
964 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("", "v") })
965 }
966
967 func TestServer_Request_Reject_HeaderFieldValueNewline(t *testing.T) {
968 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("foo", "has\nnewline") })
969 }
970
971 func TestServer_Request_Reject_HeaderFieldValueCR(t *testing.T) {
972 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("foo", "has\rcarriage") })
973 }
974
975 func TestServer_Request_Reject_HeaderFieldValueDEL(t *testing.T) {
976 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1("foo", "has\x7fdel") })
977 }
978
979 func TestServer_Request_Reject_Pseudo_Missing_method(t *testing.T) {
980 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":method", "") })
981 }
982
983 func TestServer_Request_Reject_Pseudo_ExactlyOne(t *testing.T) {
984
985
986 testRejectRequest(t, func(st *serverTester) {
987 st.addLogFilter("duplicate pseudo-header")
988 st.bodylessReq1(":method", "GET", ":method", "POST")
989 })
990 }
991
992 func TestServer_Request_Reject_Pseudo_AfterRegular(t *testing.T) {
993
994
995
996
997
998
999 testRejectRequest(t, func(st *serverTester) {
1000 st.addLogFilter("pseudo-header after regular header")
1001 var buf bytes.Buffer
1002 enc := hpack.NewEncoder(&buf)
1003 enc.WriteField(hpack.HeaderField{Name: ":method", Value: "GET"})
1004 enc.WriteField(hpack.HeaderField{Name: "regular", Value: "foobar"})
1005 enc.WriteField(hpack.HeaderField{Name: ":path", Value: "/"})
1006 enc.WriteField(hpack.HeaderField{Name: ":scheme", Value: "https"})
1007 st.writeHeaders(HeadersFrameParam{
1008 StreamID: 1,
1009 BlockFragment: buf.Bytes(),
1010 EndStream: true,
1011 EndHeaders: true,
1012 })
1013 })
1014 }
1015
1016 func TestServer_Request_Reject_Pseudo_Missing_path(t *testing.T) {
1017 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":path", "") })
1018 }
1019
1020 func TestServer_Request_Reject_Pseudo_Missing_scheme(t *testing.T) {
1021 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":scheme", "") })
1022 }
1023
1024 func TestServer_Request_Reject_Pseudo_scheme_invalid(t *testing.T) {
1025 testRejectRequest(t, func(st *serverTester) { st.bodylessReq1(":scheme", "bogus") })
1026 }
1027
1028 func TestServer_Request_Reject_Pseudo_Unknown(t *testing.T) {
1029 testRejectRequest(t, func(st *serverTester) {
1030 st.addLogFilter(`invalid pseudo-header ":unknown_thing"`)
1031 st.bodylessReq1(":unknown_thing", "")
1032 })
1033 }
1034
1035 func testRejectRequest(t *testing.T, send func(*serverTester)) {
1036 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1037 t.Error("server request made it to handler; should've been rejected")
1038 })
1039 defer st.Close()
1040
1041 st.greet()
1042 send(st)
1043 st.wantRSTStream(1, ErrCodeProtocol)
1044 }
1045
1046 func newServerTesterForError(t *testing.T) *serverTester {
1047 t.Helper()
1048 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1049 t.Error("server request made it to handler; should've been rejected")
1050 }, optQuiet)
1051 st.greet()
1052 return st
1053 }
1054
1055
1056
1057
1058 func TestRejectFrameOnIdle_WindowUpdate(t *testing.T) {
1059 st := newServerTesterForError(t)
1060 st.fr.WriteWindowUpdate(123, 456)
1061 st.wantGoAway(123, ErrCodeProtocol)
1062 }
1063 func TestRejectFrameOnIdle_Data(t *testing.T) {
1064 st := newServerTesterForError(t)
1065 st.fr.WriteData(123, true, nil)
1066 st.wantGoAway(123, ErrCodeProtocol)
1067 }
1068 func TestRejectFrameOnIdle_RSTStream(t *testing.T) {
1069 st := newServerTesterForError(t)
1070 st.fr.WriteRSTStream(123, ErrCodeCancel)
1071 st.wantGoAway(123, ErrCodeProtocol)
1072 }
1073
1074 func TestServer_Request_Connect(t *testing.T) {
1075 testServerRequest(t, func(st *serverTester) {
1076 st.writeHeaders(HeadersFrameParam{
1077 StreamID: 1,
1078 BlockFragment: st.encodeHeaderRaw(
1079 ":method", "CONNECT",
1080 ":authority", "example.com:123",
1081 ),
1082 EndStream: true,
1083 EndHeaders: true,
1084 })
1085 }, func(r *http.Request) {
1086 if g, w := r.Method, "CONNECT"; g != w {
1087 t.Errorf("Method = %q; want %q", g, w)
1088 }
1089 if g, w := r.RequestURI, "example.com:123"; g != w {
1090 t.Errorf("RequestURI = %q; want %q", g, w)
1091 }
1092 if g, w := r.URL.Host, "example.com:123"; g != w {
1093 t.Errorf("URL.Host = %q; want %q", g, w)
1094 }
1095 })
1096 }
1097
1098 func TestServer_Request_Connect_InvalidPath(t *testing.T) {
1099 testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
1100 st.writeHeaders(HeadersFrameParam{
1101 StreamID: 1,
1102 BlockFragment: st.encodeHeaderRaw(
1103 ":method", "CONNECT",
1104 ":authority", "example.com:123",
1105 ":path", "/bogus",
1106 ),
1107 EndStream: true,
1108 EndHeaders: true,
1109 })
1110 })
1111 }
1112
1113 func TestServer_Request_Connect_InvalidScheme(t *testing.T) {
1114 testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
1115 st.writeHeaders(HeadersFrameParam{
1116 StreamID: 1,
1117 BlockFragment: st.encodeHeaderRaw(
1118 ":method", "CONNECT",
1119 ":authority", "example.com:123",
1120 ":scheme", "https",
1121 ),
1122 EndStream: true,
1123 EndHeaders: true,
1124 })
1125 })
1126 }
1127
1128 func TestServer_Ping(t *testing.T) {
1129 st := newServerTester(t, nil)
1130 defer st.Close()
1131 st.greet()
1132
1133
1134 ackPingData := [8]byte{1, 2, 4, 8, 16, 32, 64, 128}
1135 if err := st.fr.WritePing(true, ackPingData); err != nil {
1136 t.Fatal(err)
1137 }
1138
1139
1140 pingData := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
1141 if err := st.fr.WritePing(false, pingData); err != nil {
1142 t.Fatal(err)
1143 }
1144
1145 pf := readFrame[*PingFrame](t, st)
1146 if !pf.Flags.Has(FlagPingAck) {
1147 t.Error("response ping doesn't have ACK set")
1148 }
1149 if pf.Data != pingData {
1150 t.Errorf("response ping has data %q; want %q", pf.Data, pingData)
1151 }
1152 }
1153
// filterListener wraps a net.Listener and passes each successfully accepted
// connection through accept, which may wrap, replace, or reject it.
type filterListener struct {
	net.Listener
	accept func(conn net.Conn) (net.Conn, error)
}

// Accept accepts from the underlying listener. On success it returns
// l.accept's result. An error from the underlying listener is returned with
// a nil conn and the filter is not called.
func (l *filterListener) Accept() (net.Conn, error) {
	conn, err := l.Listener.Accept()
	if err == nil {
		return l.accept(conn)
	}
	return nil, err
}
1166
// TestServer_MaxQueuedControlFrames floods the server with PINGs while the
// client is not reading the replies. It expects the connection to end up
// closed once more than maxQueuedControlFrames control frames are pending.
func TestServer_MaxQueuedControlFrames(t *testing.T) {
	// NOTE(review): goroutine tracking is presumably disabled because it
	// makes this test slow. Confirm.
	disableGoroutineTracking(t)

	st := newServerTester(t, nil)
	st.greet()

	// NOTE(review): a zero read buffer presumably makes the server's writes
	// block, and disabling autoWait stops each client write from syncing the
	// group. Confirm against synctestNetConn.
	st.cc.(*synctestNetConn).SetReadBufferSize(0)
	st.cc.(*synctestNetConn).autoWait = false

	// Send maxQueuedControlFrames PINGs plus a couple extra. The extras cover
	// replies that land in the server's write buffer rather than its
	// control-frame queue.
	const extraPings = 2
	for i := 0; i < maxQueuedControlFrames+extraPings; i++ {
		pingData := [8]byte{1, 2, 3, 4, 5, 6, 7, 8}
		st.fr.WritePing(false, pingData)
	}
	st.group.Wait()

	// Unblock the server. By now it should have given up on the connection
	// for exceeding the control frame limit.
	st.cc.(*synctestNetConn).SetReadBufferSize(math.MaxInt)

	st.advance(goAwayTimeout)

	// Drain up to 10 frames that may have been buffered before the close.
	for i := 0; i < 10; i++ {
		if st.readFrame() == nil {
			break
		}
	}
	st.wantClosed()
}
1199
1200 func TestServer_RejectsLargeFrames(t *testing.T) {
1201 if runtime.GOOS == "windows" || runtime.GOOS == "plan9" || runtime.GOOS == "zos" {
1202 t.Skip("see golang.org/issue/13434, golang.org/issue/37321")
1203 }
1204 st := newServerTester(t, nil)
1205 defer st.Close()
1206 st.greet()
1207
1208
1209
1210
1211 st.fr.WriteRawFrame(0xff, 0, 0, make([]byte, defaultMaxReadFrameSize+1))
1212
1213 st.wantGoAway(0, ErrCodeFrameSize)
1214 st.advance(goAwayTimeout)
1215 st.wantClosed()
1216 }
1217
// TestServer_Handler_Sends_WindowUpdate checks when the server returns
// inbound flow-control credit (WINDOW_UPDATE) as the handler consumes a
// request body.
func TestServer_Handler_Sends_WindowUpdate(t *testing.T) {
	// Connection and stream windows are twice the default initial window.
	// NOTE(review): presumably this must be large enough that greet's
	// expected connection WINDOW_UPDATE is non-zero, and small enough that
	// the second writeData fits in one frame. Confirm.
	const windowSize = 65535 * 2
	puppet := newHandlerPuppet()
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		puppet.act(w, r)
	}, func(s *Server) {
		s.MaxUploadBufferPerConnection = windowSize
		s.MaxUploadBufferPerStream = windowSize
	})
	defer st.Close()
	defer puppet.done()

	st.greet()
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(":method", "POST"),
		EndStream:     false, // body follows in DATA frames
		EndHeaders:    true,
	})

	// Send 1024 bytes and have the handler read them. No WINDOW_UPDATE is
	// expected yet; that credit only shows up after the next write below.
	data := make([]byte, windowSize)
	st.writeData(1, false, data[:1024])
	puppet.do(readBodyHandler(t, string(data[:1024])))

	// Fill the rest of the window. The server now returns the 1024 bytes of
	// credit held back above, on both the connection and the stream.
	st.writeData(1, false, data[1024:])
	st.wantWindowUpdate(0, 1024)
	st.wantWindowUpdate(1, 1024)

	// Once the handler consumes the remaining data, the server returns that
	// credit too.
	puppet.do(readBodyHandler(t, string(data[1024:])))
	st.wantWindowUpdate(0, windowSize-1024)
	st.wantWindowUpdate(1, windowSize-1024)
}
1260
1261
1262
// TestServer_Handler_Sends_WindowUpdate_Padding checks that flow-control
// credit returned for a padded DATA frame covers the whole payload: the data,
// the 1-byte Pad Length field, and the padding bytes.
func TestServer_Handler_Sends_WindowUpdate_Padding(t *testing.T) {
	const windowSize = 65535 * 2
	puppet := newHandlerPuppet()
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		puppet.act(w, r)
	}, func(s *Server) {
		s.MaxUploadBufferPerConnection = windowSize
		s.MaxUploadBufferPerStream = windowSize
	})
	defer st.Close()
	defer puppet.done()

	st.greet()
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(":method", "POST"),
		EndStream:     false,
		EndHeaders:    true,
	})

	// Send half a window of data with 4 bytes of padding in one DATA frame.
	data := make([]byte, windowSize/2)
	pad := make([]byte, 4)
	st.writeDataPadded(1, false, data, pad)

	// Once the handler consumes the data, the connection and stream updates
	// must include the Pad Length byte (+1) and the padding itself, because
	// the entire DATA payload counts against flow control.
	puppet.do(readBodyHandler(t, string(data)))
	st.wantWindowUpdate(0, uint32(len(data)+1+len(pad)))
	st.wantWindowUpdate(1, uint32(len(data)+1+len(pad)))
}
1297
1298 func TestServer_Send_GoAway_After_Bogus_WindowUpdate(t *testing.T) {
1299 st := newServerTester(t, nil)
1300 defer st.Close()
1301 st.greet()
1302 if err := st.fr.WriteWindowUpdate(0, 1<<31-1); err != nil {
1303 t.Fatal(err)
1304 }
1305 st.wantGoAway(0, ErrCodeFlowControl)
1306 }
1307
// TestServer_Send_RstStream_After_Bogus_WindowUpdate overflows the window of
// an open stream with a stream-level WINDOW_UPDATE. It expects only that
// stream to be reset with FLOW_CONTROL_ERROR.
func TestServer_Send_RstStream_After_Bogus_WindowUpdate(t *testing.T) {
	inHandler := make(chan bool)
	blockHandler := make(chan bool)
	// The handler parks until the test finishes, so stream 1 stays open
	// while the bogus update arrives.
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		inHandler <- true
		<-blockHandler
	})
	defer st.Close()
	defer close(blockHandler) // runs before st.Close (defers are LIFO)
	st.greet()
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(":method", "POST"),
		EndStream:     false,
		EndHeaders:    true,
	})
	<-inHandler

	if err := st.fr.WriteWindowUpdate(1, 1<<31-1); err != nil {
		t.Fatal(err)
	}
	st.wantRSTStream(1, ErrCodeFlowControl)
}
1331
1332
1333
1334
// testServerPostUnblock opens stream 1 as a POST without ending the stream
// and waits until handler is running. otherHeaders are appended to the header
// block. It then calls fn on the test goroutine; fn is expected to do
// something (reset the stream, close the connection, ...) that makes handler
// return. The value handler returns is passed to checkErr, if checkErr is
// non-nil.
func testServerPostUnblock(t *testing.T,
	handler func(http.ResponseWriter, *http.Request) error,
	fn func(*serverTester),
	checkErr func(error),
	otherHeaders ...string) {
	inHandler := make(chan bool)
	errc := make(chan error, 1)
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		inHandler <- true
		errc <- handler(w, r)
	})
	defer st.Close()
	st.greet()
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(append([]string{":method", "POST"}, otherHeaders...)...),
		EndStream:     false,
		EndHeaders:    true,
	})
	<-inHandler
	fn(st)
	// Blocks until handler returns; fn must have unblocked it.
	err := <-errc
	if checkErr != nil {
		checkErr(err)
	}
}
1361
// TestServer_RSTStream_Unblocks_Read checks that a handler blocked reading the
// request body is released when the client resets the stream. The read must
// fail with a StreamError for stream 1 and code 0x8 (CANCEL).
func TestServer_RSTStream_Unblocks_Read(t *testing.T) {
	testServerPostUnblock(t,
		func(w http.ResponseWriter, r *http.Request) (err error) {
			_, err = r.Body.Read(make([]byte, 1))
			return
		},
		func(st *serverTester) {
			if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
				t.Fatal(err)
			}
		},
		func(err error) {
			want := StreamError{StreamID: 0x1, Code: 0x8}
			if !reflect.DeepEqual(err, want) {
				t.Errorf("Read error = %v; want %v", err, want)
			}
		},
	)
}
1381
1382 func TestServer_RSTStream_Unblocks_Header_Write(t *testing.T) {
1383
1384
1385 n := 50
1386 if testing.Short() {
1387 n = 5
1388 }
1389 for i := 0; i < n; i++ {
1390 testServer_RSTStream_Unblocks_Header_Write(t)
1391 }
1392 }
1393
// testServer_RSTStream_Unblocks_Header_Write has the handler write and flush
// response headers only after the client has reset stream 1. The test then
// requires that this header write completes (headerWritten is signalled)
// rather than blocking forever on the dead stream.
func testServer_RSTStream_Unblocks_Header_Write(t *testing.T) {
	inHandler := make(chan bool, 1)
	unblockHandler := make(chan bool, 1)
	headerWritten := make(chan bool, 1)
	wroteRST := make(chan bool, 1)

	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		inHandler <- true
		<-wroteRST // don't write anything until the stream has been reset
		w.Header().Set("foo", "bar")
		w.WriteHeader(200)
		w.(http.Flusher).Flush()
		headerWritten <- true
		<-unblockHandler
	})
	defer st.Close()

	st.greet()
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(":method", "POST"),
		EndStream:     false,
		EndHeaders:    true,
	})
	<-inHandler
	if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
		t.Fatal(err)
	}
	wroteRST <- true
	st.awaitIdle()
	<-headerWritten
	unblockHandler <- true
}
1427
// TestServer_DeadConn_Unblocks_Read checks that closing the client connection
// makes a pending request-body Read in the handler return a non-nil error.
func TestServer_DeadConn_Unblocks_Read(t *testing.T) {
	testServerPostUnblock(t,
		func(w http.ResponseWriter, r *http.Request) (err error) {
			_, err = r.Body.Read(make([]byte, 1))
			return
		},
		func(st *serverTester) { st.cc.Close() },
		func(err error) {
			if err == nil {
				t.Error("unexpected nil error from Request.Body.Read")
			}
		},
	)
}
1442
// blockUntilClosed is a handler body for testServerPostUnblock. It waits until
// the ResponseWriter's CloseNotify channel yields a value and then returns
// nil. It panics if w does not implement http.CloseNotifier.
var blockUntilClosed = func(w http.ResponseWriter, r *http.Request) error {
	notifier := w.(http.CloseNotifier)
	closed := notifier.CloseNotify()
	<-closed
	return nil
}
1447
// TestServer_CloseNotify_After_RSTStream checks that a handler waiting on
// CloseNotify is released when the client sends RST_STREAM(CANCEL).
func TestServer_CloseNotify_After_RSTStream(t *testing.T) {
	testServerPostUnblock(t, blockUntilClosed, func(st *serverTester) {
		if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
			t.Fatal(err)
		}
	}, nil)
}
1455
1456 func TestServer_CloseNotify_After_ConnClose(t *testing.T) {
1457 testServerPostUnblock(t, blockUntilClosed, func(st *serverTester) { st.cc.Close() }, nil)
1458 }
1459
1460
1461
1462
// TestServer_CloseNotify_After_StreamError checks that a handler waiting on
// CloseNotify is released when the stream fails with a stream error.
func TestServer_CloseNotify_After_StreamError(t *testing.T) {
	testServerPostUnblock(t, blockUntilClosed, func(st *serverTester) {
		// The request declares content-length 3 (see the trailing arguments)
		// but sends 4 bytes, which is a stream error.
		st.writeData(1, true, []byte("1234"))
	}, nil, "content-length", "3")
}
1469
// TestServer_StateTransitions follows stream 1 through its server-side states:
// idle before any frames, open while the handler runs, half-closed (remote)
// after the client's END_STREAM, and closed (and removed) after the response.
func TestServer_StateTransitions(t *testing.T) {
	var st *serverTester
	inHandler := make(chan bool)
	writeData := make(chan bool)
	leaveHandler := make(chan bool)
	st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		inHandler <- true
		if st.stream(1) == nil {
			t.Errorf("nil stream 1 in handler")
		}
		if got, want := st.streamState(1), stateOpen; got != want {
			t.Errorf("in handler, state is %v; want %v", got, want)
		}
		writeData <- true
		// This read returns EOF once the client's empty END_STREAM DATA
		// frame arrives.
		if n, err := r.Body.Read(make([]byte, 1)); n != 0 || err != io.EOF {
			t.Errorf("body read = %d, %v; want 0, EOF", n, err)
		}
		if got, want := st.streamState(1), stateHalfClosedRemote; got != want {
			t.Errorf("in handler, state is %v; want %v", got, want)
		}

		<-leaveHandler
	})
	st.greet()
	if st.stream(1) != nil {
		t.Fatal("stream 1 should be empty")
	}
	if got := st.streamState(1); got != stateIdle {
		t.Fatalf("stream 1 should be idle; got %v", got)
	}

	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(":method", "POST"),
		EndStream:     false,
		EndHeaders:    true,
	})
	<-inHandler
	<-writeData
	st.writeData(1, true, nil)

	leaveHandler <- true
	st.wantHeaders(wantHeader{
		streamID:  1,
		endStream: true,
	})

	if got, want := st.streamState(1), stateClosed; got != want {
		t.Errorf("at end, state is %v; want %v", got, want)
	}
	if st.stream(1) != nil {
		t.Fatal("at end, stream 1 should be gone")
	}
}
1524
1525
// TestServer_Rejects_HeadersNoEnd_Then_Headers sends HEADERS without
// END_HEADERS and then a HEADERS frame for another stream. That interrupts a
// header block, so the server must reply with GOAWAY(PROTOCOL_ERROR).
func TestServer_Rejects_HeadersNoEnd_Then_Headers(t *testing.T) {
	st := newServerTesterForError(t)
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(),
		EndStream:     true,
		EndHeaders:    false,
	})
	st.writeHeaders(HeadersFrameParam{
		StreamID:      3,
		BlockFragment: st.encodeHeader(),
		EndStream:     true,
		EndHeaders:    true,
	})
	st.wantGoAway(0, ErrCodeProtocol)
}
1542
1543
// TestServer_Rejects_HeadersNoEnd_Then_Ping sends HEADERS without END_HEADERS
// followed by a PING. Only CONTINUATION may follow an unfinished header
// block, so a GOAWAY(PROTOCOL_ERROR) is expected.
func TestServer_Rejects_HeadersNoEnd_Then_Ping(t *testing.T) {
	st := newServerTesterForError(t)
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(),
		EndStream:     true,
		EndHeaders:    false,
	})
	if err := st.fr.WritePing(false, [8]byte{}); err != nil {
		t.Fatal(err)
	}
	st.wantGoAway(0, ErrCodeProtocol)
}
1557
1558
// TestServer_Rejects_HeadersEnd_Then_Continuation completes a request on
// stream 1 (HEADERS with END_HEADERS) and then sends a stray CONTINUATION.
// The server must answer with GOAWAY(1, PROTOCOL_ERROR).
func TestServer_Rejects_HeadersEnd_Then_Continuation(t *testing.T) {
	// NOTE(review): optQuiet presumably suppresses expected server error
	// logging; confirm.
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, optQuiet)
	st.greet()
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(),
		EndStream:     true,
		EndHeaders:    true,
	})
	st.wantHeaders(wantHeader{
		streamID:  1,
		endStream: true,
	})
	if err := st.fr.WriteContinuation(1, true, encodeHeaderNoImplicit(t, "foo", "bar")); err != nil {
		t.Fatal(err)
	}
	st.wantGoAway(1, ErrCodeProtocol)
}
1577
1578
// TestServer_Rejects_HeadersNoEnd_Then_ContinuationWrongStream leaves stream
// 1's header block open and sends a CONTINUATION on stream 3. This must be
// treated as a connection-level PROTOCOL_ERROR.
func TestServer_Rejects_HeadersNoEnd_Then_ContinuationWrongStream(t *testing.T) {
	st := newServerTesterForError(t)
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.encodeHeader(),
		EndStream:     true,
		EndHeaders:    false,
	})
	if err := st.fr.WriteContinuation(3, true, encodeHeaderNoImplicit(t, "foo", "bar")); err != nil {
		t.Fatal(err)
	}
	st.wantGoAway(0, ErrCodeProtocol)
}
1592
1593
// TestServer_Rejects_Headers0 sends HEADERS on stream 0, which is reserved for
// the connection, and expects GOAWAY(PROTOCOL_ERROR).
func TestServer_Rejects_Headers0(t *testing.T) {
	st := newServerTesterForError(t)
	st.fr.AllowIllegalWrites = true // let the framer emit the invalid frame
	st.writeHeaders(HeadersFrameParam{
		StreamID:      0,
		BlockFragment: st.encodeHeader(),
		EndStream:     true,
		EndHeaders:    true,
	})
	st.wantGoAway(0, ErrCodeProtocol)
}
1605
1606
// TestServer_Rejects_Continuation0 sends CONTINUATION on stream 0 and expects
// GOAWAY(PROTOCOL_ERROR).
func TestServer_Rejects_Continuation0(t *testing.T) {
	st := newServerTesterForError(t)
	st.fr.AllowIllegalWrites = true // let the framer emit the invalid frame
	if err := st.fr.WriteContinuation(0, true, st.encodeHeader()); err != nil {
		t.Fatal(err)
	}
	st.wantGoAway(0, ErrCodeProtocol)
}
1615
1616
// TestServer_Rejects_Priority0 sends PRIORITY on stream 0 and expects
// GOAWAY(PROTOCOL_ERROR).
func TestServer_Rejects_Priority0(t *testing.T) {
	st := newServerTesterForError(t)
	st.fr.AllowIllegalWrites = true // let the framer emit the invalid frame
	st.writePriority(0, PriorityParam{StreamDep: 1})
	st.wantGoAway(0, ErrCodeProtocol)
}
1623
1624
// TestServer_Rejects_HeadersSelfDependence opens stream 1 with a priority
// that depends on stream 1 itself. It expects RST_STREAM(PROTOCOL_ERROR) on
// that stream rather than a connection error.
func TestServer_Rejects_HeadersSelfDependence(t *testing.T) {
	testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
		st.fr.AllowIllegalWrites = true
		st.writeHeaders(HeadersFrameParam{
			StreamID:      1,
			BlockFragment: st.encodeHeader(),
			EndStream:     true,
			EndHeaders:    true,
			Priority:      PriorityParam{StreamDep: 1},
		})
	})
}
1637
1638
// TestServer_Rejects_PrioritySelfDependence sends a PRIORITY frame making
// stream 1 depend on itself and expects RST_STREAM(PROTOCOL_ERROR) on stream 1.
func TestServer_Rejects_PrioritySelfDependence(t *testing.T) {
	testServerRejectsStream(t, ErrCodeProtocol, func(st *serverTester) {
		st.fr.AllowIllegalWrites = true
		st.writePriority(1, PriorityParam{StreamDep: 1})
	})
}
1645
// TestServer_Rejects_PushPromise has the client send PUSH_PROMISE, which only
// servers may send, and expects GOAWAY(1, PROTOCOL_ERROR).
func TestServer_Rejects_PushPromise(t *testing.T) {
	st := newServerTesterForError(t)
	pp := PushPromiseParam{
		StreamID:  1,
		PromiseID: 3,
	}
	if err := st.fr.WritePushPromise(pp); err != nil {
		t.Fatal(err)
	}
	st.wantGoAway(1, ErrCodeProtocol)
}
1657
1658
1659
1660 func testServerRejectsStream(t *testing.T, code ErrCode, writeReq func(*serverTester)) {
1661 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {})
1662 defer st.Close()
1663 st.greet()
1664 writeReq(st)
1665 st.wantRSTStream(1, code)
1666 }
1667
1668
1669
1670
1671 func testServerRequest(t *testing.T, writeReq func(*serverTester), checkReq func(*http.Request)) {
1672 gotReq := make(chan bool, 1)
1673 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
1674 if r.Body == nil {
1675 t.Fatal("nil Body")
1676 }
1677 checkReq(r)
1678 gotReq <- true
1679 })
1680 defer st.Close()
1681
1682 st.greet()
1683 writeReq(st)
1684 <-gotReq
1685 }
1686
1687 func getSlash(st *serverTester) { st.bodylessReq1() }
1688
// TestServer_Response_NoData checks that a handler that writes nothing produces
// a single HEADERS frame with END_STREAM set.
func TestServer_Response_NoData(t *testing.T) {
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		// Write nothing; just return.
		return nil
	}, func(st *serverTester) {
		getSlash(st)
		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: true,
		})
	})
}
1701
// TestServer_Response_NoData_Header_FooBar checks that a custom header set by a
// handler that writes no body appears lowercased in the response HEADERS frame,
// along with ":status" 200 and content-length 0.
func TestServer_Response_NoData_Header_FooBar(t *testing.T) {
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		w.Header().Set("Foo-Bar", "some-value")
		return nil
	}, func(st *serverTester) {
		getSlash(st)
		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: true,
			header: http.Header{
				":status":        []string{"200"},
				"foo-bar":        []string{"some-value"},
				"content-length": []string{"0"},
			},
		})
	})
}
1719
1720
1721
1722 func TestServerIgnoresContentLengthSignWhenWritingChunks(t *testing.T) {
1723 tests := []struct {
1724 name string
1725 cl string
1726 wantCL string
1727 }{
1728 {
1729 name: "proper content-length",
1730 cl: "3",
1731 wantCL: "3",
1732 },
1733 {
1734 name: "ignore cl with plus sign",
1735 cl: "+3",
1736 wantCL: "0",
1737 },
1738 {
1739 name: "ignore cl with minus sign",
1740 cl: "-3",
1741 wantCL: "0",
1742 },
1743 {
1744 name: "max int64, for safe uint64->int64 conversion",
1745 cl: "9223372036854775807",
1746 wantCL: "9223372036854775807",
1747 },
1748 {
1749 name: "overflows int64, so ignored",
1750 cl: "9223372036854775808",
1751 wantCL: "0",
1752 },
1753 }
1754
1755 for _, tt := range tests {
1756 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1757 w.Header().Set("content-length", tt.cl)
1758 return nil
1759 }, func(st *serverTester) {
1760 getSlash(st)
1761 st.wantHeaders(wantHeader{
1762 streamID: 1,
1763 endStream: true,
1764 header: http.Header{
1765 ":status": []string{"200"},
1766 "content-length": []string{tt.wantCL},
1767 },
1768 })
1769 })
1770 }
1771 }
1772
1773
1774
1775 func TestServerRejectsContentLengthWithSignNewRequests(t *testing.T) {
1776 tests := []struct {
1777 name string
1778 cl string
1779 wantCL int64
1780 }{
1781 {
1782 name: "proper content-length",
1783 cl: "3",
1784 wantCL: 3,
1785 },
1786 {
1787 name: "ignore cl with plus sign",
1788 cl: "+3",
1789 wantCL: 0,
1790 },
1791 {
1792 name: "ignore cl with minus sign",
1793 cl: "-3",
1794 wantCL: 0,
1795 },
1796 {
1797 name: "max int64, for safe uint64->int64 conversion",
1798 cl: "9223372036854775807",
1799 wantCL: 9223372036854775807,
1800 },
1801 {
1802 name: "overflows int64, so ignored",
1803 cl: "9223372036854775808",
1804 wantCL: 0,
1805 },
1806 }
1807
1808 for _, tt := range tests {
1809 tt := tt
1810 t.Run(tt.name, func(t *testing.T) {
1811 writeReq := func(st *serverTester) {
1812 st.writeHeaders(HeadersFrameParam{
1813 StreamID: 1,
1814 BlockFragment: st.encodeHeader("content-length", tt.cl),
1815 EndStream: false,
1816 EndHeaders: true,
1817 })
1818 st.writeData(1, false, []byte(""))
1819 }
1820 checkReq := func(r *http.Request) {
1821 if r.ContentLength != tt.wantCL {
1822 t.Fatalf("Got: %q\nWant: %q", r.ContentLength, tt.wantCL)
1823 }
1824 }
1825 testServerRequest(t, writeReq, checkReq)
1826 })
1827 }
1828 }
1829
// TestServer_Response_Data_Sniff_DoesntOverride checks that an explicit
// Content-Type set by the handler is kept even though the body looks like
// HTML. The body is sent as a single DATA frame ending the stream.
func TestServer_Response_Data_Sniff_DoesntOverride(t *testing.T) {
	const msg = "<html>this is HTML."
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		w.Header().Set("Content-Type", "foo/bar")
		io.WriteString(w, msg)
		return nil
	}, func(st *serverTester) {
		getSlash(st)
		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: false,
			header: http.Header{
				":status":        []string{"200"},
				"content-type":   []string{"foo/bar"},
				"content-length": []string{strconv.Itoa(len(msg))},
			},
		})
		st.wantData(wantData{
			streamID:  1,
			endStream: true,
			data:      []byte(msg),
		})
	})
}
1854
// TestServer_Response_TransferEncoding_chunked checks what happens when a
// handler sets "Transfer-Encoding: chunked". That header has no meaning in
// HTTP/2, and the expected response headers do not include it. Instead they
// carry the sniffed content-type and the real content-length.
func TestServer_Response_TransferEncoding_chunked(t *testing.T) {
	const msg = "hi"
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		w.Header().Set("Transfer-Encoding", "chunked")
		io.WriteString(w, msg)
		return nil
	}, func(st *serverTester) {
		getSlash(st)
		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: false,
			header: http.Header{
				":status":        []string{"200"},
				"content-type":   []string{"text/plain; charset=utf-8"},
				"content-length": []string{strconv.Itoa(len(msg))},
			},
		})
	})
}
1874
1875
// TestServer_Response_Data_IgnoreHeaderAfterWrite_After checks that a header
// added after the body has been written is not sent. The expected header set
// contains no "foo".
func TestServer_Response_Data_IgnoreHeaderAfterWrite_After(t *testing.T) {
	const msg = "<html>this is HTML."
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		io.WriteString(w, msg)
		w.Header().Set("foo", "should be ignored")
		return nil
	}, func(st *serverTester) {
		getSlash(st)
		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: false,
			header: http.Header{
				":status":        []string{"200"},
				"content-type":   []string{"text/html; charset=utf-8"},
				"content-length": []string{strconv.Itoa(len(msg))},
			},
		})
	})
}
1895
1896
// TestServer_Response_Data_IgnoreHeaderAfterWrite_Overwrite checks that
// changing a header after the body has been written does not affect the value
// already captured. The response carries the value set before the write.
func TestServer_Response_Data_IgnoreHeaderAfterWrite_Overwrite(t *testing.T) {
	const msg = "<html>this is HTML."
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		w.Header().Set("foo", "proper value")
		io.WriteString(w, msg)
		w.Header().Set("foo", "should be ignored")
		return nil
	}, func(st *serverTester) {
		getSlash(st)
		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: false,
			header: http.Header{
				":status":        []string{"200"},
				"foo":            []string{"proper value"},
				"content-type":   []string{"text/html; charset=utf-8"},
				"content-length": []string{strconv.Itoa(len(msg))},
			},
		})
	})
}
1918
// TestServer_Response_Data_SniffLenType checks that a small body written
// without headers gets a sniffed content-type and an exact content-length. The
// body arrives as one DATA frame ending the stream.
func TestServer_Response_Data_SniffLenType(t *testing.T) {
	const msg = "<html>this is HTML."
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		io.WriteString(w, msg)
		return nil
	}, func(st *serverTester) {
		getSlash(st)
		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: false,
			header: http.Header{
				":status":        []string{"200"},
				"content-type":   []string{"text/html; charset=utf-8"},
				"content-length": []string{strconv.Itoa(len(msg))},
			},
		})
		st.wantData(wantData{
			streamID:  1,
			endStream: true,
			data:      []byte(msg),
		})
	})
}
1942
// TestServer_Response_Header_Flush_MidWrite checks what happens when the
// handler flushes between two writes. The headers go out with the sniffed
// content-type but without content-length, because the total length isn't
// known at flush time. Each write then arrives as its own DATA frame.
func TestServer_Response_Header_Flush_MidWrite(t *testing.T) {
	const msg = "<html>this is HTML"
	const msg2 = ", and this is the next chunk"
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		io.WriteString(w, msg)
		w.(http.Flusher).Flush()
		io.WriteString(w, msg2)
		return nil
	}, func(st *serverTester) {
		getSlash(st)
		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: false,
			header: http.Header{
				":status":      []string{"200"},
				"content-type": []string{"text/html; charset=utf-8"},
				// no content-length expected: the handler flushed before finishing
			},
		})
		st.wantData(wantData{
			streamID:  1,
			endStream: false,
			data:      []byte(msg),
		})
		st.wantData(wantData{
			streamID:  1,
			endStream: true,
			data:      []byte(msg2),
		})
	})
}
1974
1975 func TestServer_Response_LargeWrite(t *testing.T) {
1976 const size = 1 << 20
1977 const maxFrameSize = 16 << 10
1978 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
1979 n, err := w.Write(bytes.Repeat([]byte("a"), size))
1980 if err != nil {
1981 return fmt.Errorf("Write error: %v", err)
1982 }
1983 if n != size {
1984 return fmt.Errorf("wrong size %d from Write", n)
1985 }
1986 return nil
1987 }, func(st *serverTester) {
1988 if err := st.fr.WriteSettings(
1989 Setting{SettingInitialWindowSize, 0},
1990 Setting{SettingMaxFrameSize, maxFrameSize},
1991 ); err != nil {
1992 t.Fatal(err)
1993 }
1994 st.wantSettingsAck()
1995
1996 getSlash(st)
1997
1998
1999 if err := st.fr.WriteWindowUpdate(1, size); err != nil {
2000 t.Fatal(err)
2001 }
2002
2003
2004 if err := st.fr.WriteWindowUpdate(0, size); err != nil {
2005 t.Fatal(err)
2006 }
2007 st.wantHeaders(wantHeader{
2008 streamID: 1,
2009 endStream: false,
2010 header: http.Header{
2011 ":status": []string{"200"},
2012 "content-type": []string{"text/plain; charset=utf-8"},
2013
2014 },
2015 })
2016 var bytes, frames int
2017 for {
2018 df := readFrame[*DataFrame](t, st)
2019 bytes += len(df.Data())
2020 frames++
2021 for _, b := range df.Data() {
2022 if b != 'a' {
2023 t.Fatal("non-'a' byte seen in DATA")
2024 }
2025 }
2026 if df.StreamEnded() {
2027 break
2028 }
2029 }
2030 if bytes != size {
2031 t.Errorf("Got %d bytes; want %d", bytes, size)
2032 }
2033 if want := int(size / maxFrameSize); frames < want || frames > want*2 {
2034 t.Errorf("Got %d frames; want %d", frames, size)
2035 }
2036 })
2037 }
2038
2039
// TestServer_Response_LargeWrite_FlowControlled checks that the server never
// sends more response data than the stream window allows. The initial window
// is reads[0] bytes. Each later WINDOW_UPDATE must release exactly that many
// bytes, and the last release ends the stream.
func TestServer_Response_LargeWrite_FlowControlled(t *testing.T) {
	// Window increments the client hands out, in order. size is their sum,
	// which is exactly the response body length.
	reads := []int{123, 1, 13, 127}
	size := 0
	for _, n := range reads {
		size += n
	}

	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		w.(http.Flusher).Flush()
		n, err := w.Write(bytes.Repeat([]byte("a"), size))
		if err != nil {
			return fmt.Errorf("Write error: %v", err)
		}
		if n != size {
			return fmt.Errorf("wrong size %d from Write", n)
		}
		return nil
	}, func(st *serverTester) {
		// Set the initial per-stream window to the first chunk size. The
		// connection window is left at its default.
		if err := st.fr.WriteSettings(Setting{SettingInitialWindowSize, uint32(reads[0])}); err != nil {
			t.Fatal(err)
		}
		st.wantSettingsAck()

		getSlash(st)

		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: false,
		})

		st.wantData(wantData{
			streamID:  1,
			endStream: false,
			size:      reads[0],
		})

		for i, quota := range reads[1:] {
			if err := st.fr.WriteWindowUpdate(1, uint32(quota)); err != nil {
				t.Fatal(err)
			}
			st.wantData(wantData{
				streamID:  1,
				endStream: i == len(reads[1:])-1,
				size:      quota,
			})
		}
	})
}
2092
2093
// TestServer_Response_RST_Unblocks_LargeWrite gives the stream a zero initial
// window, so the handler's large Write cannot make progress. It then checks
// that RST_STREAM from the client unblocks that Write with a non-nil error.
func TestServer_Response_RST_Unblocks_LargeWrite(t *testing.T) {
	const size = 1 << 20
	const maxFrameSize = 16 << 10
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		w.(http.Flusher).Flush()
		_, err := w.Write(bytes.Repeat([]byte("a"), size))
		if err == nil {
			return errors.New("unexpected nil error from Write in handler")
		}
		return nil
	}, func(st *serverTester) {
		if err := st.fr.WriteSettings(
			Setting{SettingInitialWindowSize, 0},
			Setting{SettingMaxFrameSize, maxFrameSize},
		); err != nil {
			t.Fatal(err)
		}
		st.wantSettingsAck()

		getSlash(st)

		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: false,
		})

		if err := st.fr.WriteRSTStream(1, ErrCodeCancel); err != nil {
			t.Fatal(err)
		}
	})
}
2125
// TestServer_Response_Empty_Data_Not_FlowControlled checks that the final
// zero-length DATA frame carrying END_STREAM is sent even when the stream
// window is 0. Empty DATA frames consume no flow-control window.
func TestServer_Response_Empty_Data_Not_FlowControlled(t *testing.T) {
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		w.(http.Flusher).Flush()
		// Flush sends headers; returning without writing ends the stream.
		return nil
	}, func(st *serverTester) {
		// Zero initial window for every stream.
		if err := st.fr.WriteSettings(Setting{SettingInitialWindowSize, 0}); err != nil {
			t.Fatal(err)
		}
		st.wantSettingsAck()

		getSlash(st)

		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: false,
		})

		st.wantData(wantData{
			streamID:  1,
			endStream: true,
			size:      0,
		})
	})
}
2152
// TestServer_Response_Automatic100Continue covers a request with
// "expect: 100-Continue". The handler must not see the Expect header. A
// ":status" 100 HEADERS frame is expected before the client sends the body.
// The final 200 response follows once the handler has read the body.
func TestServer_Response_Automatic100Continue(t *testing.T) {
	const msg = "foo"
	const reply = "bar"
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		if v := r.Header.Get("Expect"); v != "" {
			t.Errorf("Expect header = %q; want empty", v)
		}
		buf := make([]byte, len(msg))
		// Reading the body is what prompts the server to send 100 Continue.
		if n, err := io.ReadFull(r.Body, buf); err != nil || n != len(msg) || string(buf) != msg {
			return fmt.Errorf("ReadFull = %q, %v; want %q, nil", buf[:n], err, msg)
		}
		_, err := io.WriteString(w, reply)
		return err
	}, func(st *serverTester) {
		st.writeHeaders(HeadersFrameParam{
			StreamID:      1,
			BlockFragment: st.encodeHeader(":method", "POST", "expect", "100-Continue"),
			EndStream:     false,
			EndHeaders:    true,
		})
		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: false,
			header: http.Header{
				":status": []string{"100"},
			},
		})

		// Only now send the body; the handler is waiting on it.
		st.writeData(1, true, []byte(msg))

		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: false,
			header: http.Header{
				":status":        []string{"200"},
				"content-type":   []string{"text/plain; charset=utf-8"},
				"content-length": []string{strconv.Itoa(len(reply))},
			},
		})

		st.wantData(wantData{
			streamID:  1,
			endStream: true,
			data:      []byte(reply),
		})
	})
}
2203
// TestServer_HandlerWriteErrorOnDisconnect has the handler write in an endless
// loop. It checks that closing the client connection eventually makes Write
// return an error (which the handler forwards on errc), so the loop ends.
func TestServer_HandlerWriteErrorOnDisconnect(t *testing.T) {
	errc := make(chan error, 1)
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		p := []byte("some data.\n")
		for {
			_, err := w.Write(p)
			if err != nil {
				errc <- err
				return nil
			}
		}
	}, func(st *serverTester) {
		st.writeHeaders(HeadersFrameParam{
			StreamID:      1,
			BlockFragment: st.encodeHeader(),
			EndStream:     false,
			EndHeaders:    true,
		})
		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: false,
		})
		// Drop the connection and wait for the handler's Write to fail.
		st.cc.Close()
		_ = <-errc
	})
}
2231
// TestServer_Rejects_Too_Many_Streams opens defaultMaxStreams concurrent
// streams and then one more. The extra stream is sent as HEADERS plus
// CONTINUATION and must be reset with PROTOCOL_ERROR. After one stream
// finishes, a new stream must be accepted, and the handler must decode its
// :path correctly. That shows the rejected header block was still processed
// and the decoder state stayed in sync.
func TestServer_Rejects_Too_Many_Streams(t *testing.T) {
	const testPath = "/some/path"

	inHandler := make(chan uint32)
	leaveHandler := make(chan bool)
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		id := w.(*responseWriter).rws.stream.id
		inHandler <- id
		// The final ("good") stream ID is 1+(defaultMaxStreams+1)*2.
		if id == 1+(defaultMaxStreams+1)*2 && r.URL.Path != testPath {
			t.Errorf("decoded final path as %q; want %q", r.URL.Path, testPath)
		}
		<-leaveHandler
	})
	defer st.Close()

	// NOTE(review): auto-waiting is disabled, presumably because handlers
	// stay blocked and the server never goes idle; the test calls st.sync()
	// explicitly where it needs to wait.
	st.cc.(*synctestNetConn).autoWait = false

	st.greet()
	nextStreamID := uint32(1)
	streamID := func() uint32 {
		defer func() { nextStreamID += 2 }()
		return nextStreamID
	}
	sendReq := func(id uint32, headers ...string) {
		st.writeHeaders(HeadersFrameParam{
			StreamID:      id,
			BlockFragment: st.encodeHeader(headers...),
			EndStream:     true,
			EndHeaders:    true,
		})
	}
	// Fill every allowed stream slot with a handler that blocks.
	for i := 0; i < defaultMaxStreams; i++ {
		sendReq(streamID())
		<-inHandler
	}
	defer func() {
		for i := 0; i < defaultMaxStreams; i++ {
			leaveHandler <- true
		}
	}()

	// One stream too many. Split its header block across HEADERS and
	// CONTINUATION so the server has to process the full block even though it
	// rejects the stream.
	rejectID := streamID()
	headerBlock := st.encodeHeader(":path", testPath)
	frag1, frag2 := headerBlock[:3], headerBlock[3:]
	st.writeHeaders(HeadersFrameParam{
		StreamID:      rejectID,
		BlockFragment: frag1,
		EndStream:     true,
		EndHeaders:    false,
	})
	if err := st.fr.WriteContinuation(rejectID, true, frag2); err != nil {
		t.Fatal(err)
	}
	st.sync()
	st.wantRSTStream(rejectID, ErrCodeProtocol)

	// Let one handler finish (stream 1), freeing a slot.
	leaveHandler <- true
	st.sync()
	st.wantHeaders(wantHeader{
		streamID:  1,
		endStream: true,
	})

	// A new stream now fits.
	goodID := streamID()
	sendReq(goodID, ":path", testPath)
	if got := <-inHandler; got != goodID {
		t.Errorf("Got stream %d; want %d", got, goodID)
	}
}
2308
2309
// TestServer_Response_ManyHeaders_With_Continuation sets 5000 response headers,
// far more than fit in one frame. It checks that the server sends a HEADERS
// frame without END_HEADERS followed by at least 5 CONTINUATION frames, the
// last of which ends the header block.
func TestServer_Response_ManyHeaders_With_Continuation(t *testing.T) {
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		h := w.Header()
		for i := 0; i < 5000; i++ {
			h.Set(fmt.Sprintf("x-header-%d", i), fmt.Sprintf("x-value-%d", i))
		}
		return nil
	}, func(st *serverTester) {
		getSlash(st)
		hf := readFrame[*HeadersFrame](t, st)
		if hf.HeadersEnded() {
			t.Fatal("got unwanted END_HEADERS flag")
		}
		n := 0
		for {
			n++
			cf := readFrame[*ContinuationFrame](t, st)
			if cf.HeadersEnded() {
				break
			}
		}
		if n < 5 {
			t.Errorf("Only got %d CONTINUATION frames; expected 5+ (currently 6)", n)
		}
	})
}
2336
2337
2338
2339
2340
2341
2342
2343
// TestServer_NoCrash_HandlerClose_Then_ClientClose finishes the response while
// the client's side of the stream is still open. It then has the client send
// DATA on the closed stream and finally close the connection. A test hook
// records any panic the server hits, and the test fails if one occurred.
func TestServer_NoCrash_HandlerClose_Then_ClientClose(t *testing.T) {
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		// Return immediately without reading the request body.
		return nil
	}, func(st *serverTester) {
		st.writeHeaders(HeadersFrameParam{
			StreamID:      1,
			BlockFragment: st.encodeHeader(),
			EndStream:     false,
			EndHeaders:    true,
		})
		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: true,
		})

		// The client never ended its half of the stream, so after the
		// response the server resets the stream with NO_ERROR.
		st.wantRSTStream(1, ErrCodeNo)

		// Send DATA on the now-closed stream.
		st.writeData(1, true, []byte("foo"))

		// The server answers with RST_STREAM(STREAM_CLOSED).
		st.wantRSTStream(1, ErrCodeStreamClosed)

		// NOTE(review): presumably asserts the connection-level flow-control
		// accounting is back to zero after the discarded DATA; confirm the
		// meaning of wantFlowControlConsumed's arguments.
		st.wantFlowControlConsumed(0, 0)

		// Install a hook that records (and swallows) any panic in the
		// server's connection handling.
		var (
			panMu    sync.Mutex
			panicVal interface{}
		)

		testHookOnPanicMu.Lock()
		testHookOnPanic = func(sc *serverConn, pv interface{}) bool {
			panMu.Lock()
			panicVal = pv
			panMu.Unlock()
			return true
		}
		testHookOnPanicMu.Unlock()

		// Close the client connection and wait for the server to stop serving.
		st.cc.Close()
		<-st.sc.doneServing

		panMu.Lock()
		got := panicVal
		panMu.Unlock()
		if got != nil {
			t.Errorf("Got panic: %v", got)
		}
	})
}
2409
2410 func TestServer_Rejects_TLS10(t *testing.T) { testRejectTLS(t, tls.VersionTLS10) }
2411 func TestServer_Rejects_TLS11(t *testing.T) { testRejectTLS(t, tls.VersionTLS11) }
2412
// testRejectTLS makes the server see a connection whose TLS state reports the
// given version. It expects the server to refuse the connection with
// GOAWAY(INADEQUATE_SECURITY).
func testRejectTLS(t *testing.T, version uint16) {
	st := newServerTester(t, nil, func(state *tls.ConnectionState) {
		// NOTE(review): the version is overridden directly on the connection
		// state, presumably because the real handshake would not negotiate
		// these old versions with current TLS defaults; confirm.
		state.Version = version
	})
	defer st.Close()
	st.wantGoAway(0, ErrCodeInadequateSecurity)
}
2423
// TestServer_Rejects_TLSBadCipher makes the server see a TLS 1.2 connection
// using TLS_RSA_WITH_RC4_128_SHA, a cipher suite HTTP/2 prohibits. It expects
// GOAWAY(INADEQUATE_SECURITY).
func TestServer_Rejects_TLSBadCipher(t *testing.T) {
	st := newServerTester(t, nil, func(state *tls.ConnectionState) {
		state.Version = tls.VersionTLS12
		state.CipherSuite = tls.TLS_RSA_WITH_RC4_128_SHA
	})
	defer st.Close()
	st.wantGoAway(0, ErrCodeInadequateSecurity)
}
2432
// TestServer_Advertises_Common_Cipher checks that a server using the default
// TLS configuration can complete a request with a client restricted to TLS 1.2
// and the single suite TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256. The HTTP/2 spec
// requires that suite to be supported.
func TestServer_Advertises_Common_Cipher(t *testing.T) {
	ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
	}, func(srv *http.Server) {
		// NOTE(review): clearing TLSConfig presumably makes the test server
		// fall back to its default TLS configuration, so the check exercises
		// the defaults; confirm against newTestServer.
		srv.TLSConfig = nil
	})

	// The only suite the client offers.
	const requiredSuite = tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
	tlsConfig := tlsConfigInsecure.Clone()
	tlsConfig.MaxVersion = tls.VersionTLS12
	tlsConfig.CipherSuites = []uint16{requiredSuite}
	tr := &Transport{TLSClientConfig: tlsConfig}
	defer tr.CloseIdleConnections()

	req, err := http.NewRequest("GET", ts.URL, nil)
	if err != nil {
		t.Fatal(err)
	}
	res, err := tr.RoundTrip(req)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
}
2459
2460
2461
2462 func testServerResponse(t testing.TB,
2463 handler func(http.ResponseWriter, *http.Request) error,
2464 client func(*serverTester),
2465 ) {
2466 errc := make(chan error, 1)
2467 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
2468 if r.Body == nil {
2469 t.Fatal("nil Body")
2470 }
2471 err := handler(w, r)
2472 select {
2473 case errc <- err:
2474 default:
2475 t.Errorf("unexpected duplicate request")
2476 }
2477 })
2478 defer st.Close()
2479
2480 st.greet()
2481 client(st)
2482
2483 if err := <-errc; err != nil {
2484 t.Fatalf("Error in handler: %v", err)
2485 }
2486 }
2487
2488
2489
2490
// readBodyHandler returns a handler that reads exactly len(want) bytes from
// the request body. It reports a test error if that read fails or the bytes
// differ from want. The handler writes nothing to the response, and any body
// beyond len(want) bytes is left unread.
func readBodyHandler(t *testing.T, want string) func(w http.ResponseWriter, r *http.Request) {
	return func(w http.ResponseWriter, r *http.Request) {
		got := make([]byte, len(want))
		if _, err := io.ReadFull(r.Body, got); err != nil {
			t.Error(err)
			return
		}
		if want != string(got) {
			t.Errorf("read %q; want %q", got, want)
		}
	}
}
2504
2505 func TestServer_MaxDecoderHeaderTableSize(t *testing.T) {
2506 wantHeaderTableSize := uint32(initialHeaderTableSize * 2)
2507 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, func(s *Server) {
2508 s.MaxDecoderHeaderTableSize = wantHeaderTableSize
2509 })
2510 defer st.Close()
2511
2512 var advHeaderTableSize *uint32
2513 st.greetAndCheckSettings(func(s Setting) error {
2514 switch s.ID {
2515 case SettingHeaderTableSize:
2516 advHeaderTableSize = &s.Val
2517 }
2518 return nil
2519 })
2520
2521 if advHeaderTableSize == nil {
2522 t.Errorf("server didn't advertise a header table size")
2523 } else if got, want := *advHeaderTableSize, wantHeaderTableSize; got != want {
2524 t.Errorf("server advertised a header table size of %d, want %d", got, want)
2525 }
2526 }
2527
2528 func TestServer_MaxEncoderHeaderTableSize(t *testing.T) {
2529 wantHeaderTableSize := uint32(initialHeaderTableSize / 2)
2530 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {}, func(s *Server) {
2531 s.MaxEncoderHeaderTableSize = wantHeaderTableSize
2532 })
2533 defer st.Close()
2534
2535 st.greet()
2536
2537 if got, want := st.sc.hpackEncoder.MaxDynamicTableSize(), wantHeaderTableSize; got != want {
2538 t.Errorf("server encoder is using a header table size of %d, want %d", got, want)
2539 }
2540 }
2541
2542
// TestServerDoS_MaxHeaderListSize checks that the server advertises a non-zero
// SETTINGS_MAX_HEADER_LIST_SIZE, and that a request whose header block is far
// larger (one HEADERS frame followed by roughly 1 MiB of CONTINUATION frames)
// is answered with a 431 response.
func TestServerDoS_MaxHeaderListSize(t *testing.T) {
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {})
	defer st.Close()

	// Record the server's advertised SETTINGS_MAX_FRAME_SIZE, clamped to
	// [minMaxFrameSize, maxFrameSize], so the CONTINUATION frames below never
	// exceed it. Also capture the advertised max header list size.
	frameSize := defaultMaxReadFrameSize
	var advHeaderListSize *uint32
	st.greetAndCheckSettings(func(s Setting) error {
		switch s.ID {
		case SettingMaxFrameSize:
			if s.Val < minMaxFrameSize {
				frameSize = minMaxFrameSize
			} else if s.Val > maxFrameSize {
				frameSize = maxFrameSize
			} else {
				frameSize = int(s.Val)
			}
		case SettingMaxHeaderListSize:
			advHeaderListSize = &s.Val
		}
		return nil
	})

	if advHeaderListSize == nil {
		t.Errorf("server didn't advertise a max header list size")
	} else if *advHeaderListSize == 0 {
		t.Errorf("server advertised a max header list size of 0")
	}

	// First HEADERS frame: pseudo-headers plus one 4058-byte cookie, with
	// END_HEADERS unset so the block continues in CONTINUATION frames.
	st.encodeHeaderField(":method", "GET")
	st.encodeHeaderField(":path", "/")
	st.encodeHeaderField(":scheme", "https")
	cookie := strings.Repeat("*", 4058)
	st.encodeHeaderField("cookie", cookie)
	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: st.headerBuf.Bytes(),
		EndStream:     true,
		EndHeaders:    false,
	})

	// Re-encode only the cookie field; its encoding is the unit that is
	// repeated to build the CONTINUATION payload.
	st.headerBuf.Reset()
	st.encodeHeaderField("cookie", cookie)

	// Repeat the encoded cookie up to ~1 MiB and send it in CONTINUATION
	// frames of at most frameSize bytes; only the last sets END_HEADERS.
	// NOTE(review): WriteContinuation errors are ignored here.
	const size = 1 << 20
	b := bytes.Repeat(st.headerBuf.Bytes(), size/st.headerBuf.Len())
	for len(b) > 0 {
		chunk := b
		if len(chunk) > frameSize {
			chunk = chunk[:frameSize]
		}
		b = b[len(chunk):]
		st.fr.WriteContinuation(1, len(b) == 0, chunk)
	}

	// The oversized header block is rejected with 431 rather than by
	// tearing down the connection.
	st.wantHeaders(wantHeader{
		streamID:  1,
		endStream: false,
		header: http.Header{
			":status":        []string{"431"},
			"content-type":   []string{"text/html; charset=utf-8"},
			"content-length": []string{"63"},
		},
	})
}
2611
2612 func TestServer_Response_Stream_With_Missing_Trailer(t *testing.T) {
2613 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2614 w.Header().Set("Trailer", "test-trailer")
2615 return nil
2616 }, func(st *serverTester) {
2617 getSlash(st)
2618 st.wantHeaders(wantHeader{
2619 streamID: 1,
2620 endStream: false,
2621 })
2622 st.wantData(wantData{
2623 streamID: 1,
2624 endStream: true,
2625 size: 0,
2626 })
2627 })
2628 }
2629
// TestCompressionErrorOnWrite probes the framer's maximum HPACK header string
// length, which is read after a connection is set up with MaxHeaderBytes of
// 8 KiB. A header value of exactly that length is decoded but rejected with a
// 431 response. A value one byte longer causes a connection-level GOAWAY with
// COMPRESSION_ERROR.
func TestCompressionErrorOnWrite(t *testing.T) {
	const maxStrLen = 8 << 10
	var serverConfig *http.Server
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		// Intentionally empty handler.
	}, func(s *http.Server) {
		serverConfig = s
		serverConfig.MaxHeaderBytes = maxStrLen
	})
	st.addLogFilter("connection error: COMPRESSION_ERROR")
	defer st.Close()
	st.greet()

	// The limit derived from the low MaxHeaderBytes value for this connection.
	maxAllowed := st.sc.framer.maxHeaderStringLen()

	// Raise MaxHeaderBytes now that the connection's limit has been taken
	// from the low value. NOTE(review): presumably this keeps the header-list
	// limit out of the way so only the string-length limit is exercised —
	// confirm against the server's limit plumbing.
	serverConfig.MaxHeaderBytes = 1 << 20

	// A header value of exactly maxAllowed bytes: the HPACK decoder accepts
	// it (the connection survives), but the request gets a 431.
	hbf := st.encodeHeader("foo", strings.Repeat("a", maxAllowed))

	st.writeHeaders(HeadersFrameParam{
		StreamID:      1,
		BlockFragment: hbf,
		EndStream:     true,
		EndHeaders:    true,
	})
	st.wantHeaders(wantHeader{
		streamID:  1,
		endStream: false,
		header: http.Header{
			":status":        []string{"431"},
			"content-type":   []string{"text/html; charset=utf-8"},
			"content-length": []string{"63"},
		},
	})
	df := readFrame[*DataFrame](t, st)
	if !strings.Contains(string(df.Data()), "HTTP Error 431") {
		t.Errorf("Unexpected data body: %q", df.Data())
	}
	if !df.StreamEnded() {
		t.Fatalf("expect data stream end")
	}

	// One byte over the limit: the whole connection fails with
	// COMPRESSION_ERROR.
	hbf = st.encodeHeader("bar", strings.Repeat("b", maxAllowed+1))
	st.writeHeaders(HeadersFrameParam{
		StreamID:      3,
		BlockFragment: hbf,
		EndStream:     true,
		EndHeaders:    true,
	})
	st.wantGoAway(3, ErrCodeCompression)
}
2691
2692 func TestCompressionErrorOnClose(t *testing.T) {
2693 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
2694
2695 })
2696 st.addLogFilter("connection error: COMPRESSION_ERROR")
2697 defer st.Close()
2698 st.greet()
2699
2700 hbf := st.encodeHeader("foo", "bar")
2701 hbf = hbf[:len(hbf)-1]
2702 st.writeHeaders(HeadersFrameParam{
2703 StreamID: 1,
2704 BlockFragment: hbf,
2705 EndStream: true,
2706 EndHeaders: true,
2707 })
2708 st.wantGoAway(1, ErrCodeCompression)
2709 }
2710
2711
2712 func TestServerReadsTrailers(t *testing.T) {
2713 const testBody = "some test body"
2714 writeReq := func(st *serverTester) {
2715 st.writeHeaders(HeadersFrameParam{
2716 StreamID: 1,
2717 BlockFragment: st.encodeHeader("trailer", "Foo, Bar", "trailer", "Baz"),
2718 EndStream: false,
2719 EndHeaders: true,
2720 })
2721 st.writeData(1, false, []byte(testBody))
2722 st.writeHeaders(HeadersFrameParam{
2723 StreamID: 1,
2724 BlockFragment: st.encodeHeaderRaw(
2725 "foo", "foov",
2726 "bar", "barv",
2727 "baz", "bazv",
2728 "surprise", "wasn't declared; shouldn't show up",
2729 ),
2730 EndStream: true,
2731 EndHeaders: true,
2732 })
2733 }
2734 checkReq := func(r *http.Request) {
2735 wantTrailer := http.Header{
2736 "Foo": nil,
2737 "Bar": nil,
2738 "Baz": nil,
2739 }
2740 if !reflect.DeepEqual(r.Trailer, wantTrailer) {
2741 t.Errorf("initial Trailer = %v; want %v", r.Trailer, wantTrailer)
2742 }
2743 slurp, err := io.ReadAll(r.Body)
2744 if string(slurp) != testBody {
2745 t.Errorf("read body %q; want %q", slurp, testBody)
2746 }
2747 if err != nil {
2748 t.Fatalf("Body slurp: %v", err)
2749 }
2750 wantTrailerAfter := http.Header{
2751 "Foo": {"foov"},
2752 "Bar": {"barv"},
2753 "Baz": {"bazv"},
2754 }
2755 if !reflect.DeepEqual(r.Trailer, wantTrailerAfter) {
2756 t.Errorf("final Trailer = %v; want %v", r.Trailer, wantTrailerAfter)
2757 }
2758 }
2759 testServerRequest(t, writeReq, checkReq)
2760 }
2761
2762
2763 func TestServerWritesTrailers_WithFlush(t *testing.T) { testServerWritesTrailers(t, true) }
2764 func TestServerWritesTrailers_WithoutFlush(t *testing.T) { testServerWritesTrailers(t, false) }
2765
// testServerWritesTrailers checks which response trailers are sent. Only
// trailers that were declared (via the "Trailer" header or a "Trailer:" key
// prefix) and then set are sent. Declared-but-unset and undeclared fields are
// dropped. Transfer-Encoding, Content-Length, Trailer, Range and invalid
// names are also dropped. withFlush controls whether the handler flushes
// after writing the body.
func testServerWritesTrailers(t *testing.T, withFlush bool) {
	testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
		// Declare trailers up front, including names that must never be
		// sent as trailers (Transfer-Encoding, Content-Length, Trailer).
		w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
		w.Header().Add("Trailer", "Server-Trailer-C")
		w.Header().Add("Trailer", "Transfer-Encoding, Content-Length, Trailer")

		// Ordinary response headers.
		w.Header().Set("Foo", "Bar")
		w.Header().Set("Content-Length", "5")

		io.WriteString(w, "Hello")
		if withFlush {
			w.(http.Flusher).Flush()
		}
		// Set two of the three declared trailers; B stays unset.
		w.Header().Set("Server-Trailer-A", "valuea")
		w.Header().Set("Server-Trailer-C", "valuec")
		// Undeclared; not expected in the trailer block.
		w.Header().Set("Server-Surpise", "surprise! this isn't predeclared!")

		// "Trailer:"-prefixed keys declare trailers after the headers were
		// written. Range and the name with a control byte are expected to be
		// dropped, as are the forbidden names set below.
		w.Header().Set("Trailer:Post-Header-Trailer", "hi1")
		w.Header().Set("Trailer:post-header-trailer2", "hi2")
		w.Header().Set("Trailer:Range", "invalid")
		w.Header().Set("Trailer:Foo\x01Bogus", "invalid")
		w.Header().Set("Transfer-Encoding", "should not be included; Forbidden by RFC 7230 4.1.2")
		w.Header().Set("Content-Length", "should not be included; Forbidden by RFC 7230 4.1.2")
		w.Header().Set("Trailer", "should not be included; Forbidden by RFC 7230 4.1.2")
		return nil
	}, func(st *serverTester) {
		getSlash(st)
		// The response headers repeat the trailer declarations verbatim.
		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: false,
			header: http.Header{
				":status": []string{"200"},
				"foo":     []string{"Bar"},
				"trailer": []string{
					"Server-Trailer-A, Server-Trailer-B",
					"Server-Trailer-C",
					"Transfer-Encoding, Content-Length, Trailer",
				},
				"content-type":   []string{"text/plain; charset=utf-8"},
				"content-length": []string{"5"},
			},
		})
		st.wantData(wantData{
			streamID:  1,
			endStream: false,
			data:      []byte("Hello"),
		})
		// Only the valid, declared-and-set trailers end the stream.
		st.wantHeaders(wantHeader{
			streamID:  1,
			endStream: true,
			header: http.Header{
				"post-header-trailer":  []string{"hi1"},
				"post-header-trailer2": []string{"hi2"},
				"server-trailer-a":     []string{"valuea"},
				"server-trailer-c":     []string{"valuec"},
			},
		})
	})
}
2830
2831 func TestServerWritesUndeclaredTrailers(t *testing.T) {
2832 const trailer = "Trailer-Header"
2833 const value = "hi1"
2834 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
2835 w.Header().Set(http.TrailerPrefix+trailer, value)
2836 })
2837
2838 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
2839 defer tr.CloseIdleConnections()
2840
2841 cl := &http.Client{Transport: tr}
2842 resp, err := cl.Get(ts.URL)
2843 if err != nil {
2844 t.Fatal(err)
2845 }
2846 io.Copy(io.Discard, resp.Body)
2847 resp.Body.Close()
2848
2849 if got, want := resp.Trailer.Get(trailer), value; got != want {
2850 t.Errorf("trailer %v = %q, want %q", trailer, got, want)
2851 }
2852 }
2853
2854
2855
2856 func TestServerDoesntWriteInvalidHeaders(t *testing.T) {
2857 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
2858 w.Header().Add("OK1", "x")
2859 w.Header().Add("Bad:Colon", "x")
2860 w.Header().Add("Bad1\x00", "x")
2861 w.Header().Add("Bad2", "x\x00y")
2862 return nil
2863 }, func(st *serverTester) {
2864 getSlash(st)
2865 st.wantHeaders(wantHeader{
2866 streamID: 1,
2867 endStream: true,
2868 header: http.Header{
2869 ":status": []string{"200"},
2870 "ok1": []string{"x"},
2871 "content-length": []string{"0"},
2872 },
2873 })
2874 })
2875 }
2876
2877 func BenchmarkServerGets(b *testing.B) {
2878 disableGoroutineTracking(b)
2879 b.ReportAllocs()
2880
2881 const msg = "Hello, world"
2882 st := newServerTesterWithRealConn(b, func(w http.ResponseWriter, r *http.Request) {
2883 io.WriteString(w, msg)
2884 })
2885 defer st.Close()
2886 st.greet()
2887
2888
2889 if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil {
2890 b.Fatal(err)
2891 }
2892
2893 for i := 0; i < b.N; i++ {
2894 id := 1 + uint32(i)*2
2895 st.writeHeaders(HeadersFrameParam{
2896 StreamID: id,
2897 BlockFragment: st.encodeHeader(),
2898 EndStream: true,
2899 EndHeaders: true,
2900 })
2901 st.wantFrameType(FrameHeaders)
2902 if df := readFrame[*DataFrame](b, st); !df.StreamEnded() {
2903 b.Fatalf("DATA didn't have END_STREAM; got %v", df)
2904 }
2905 }
2906 }
2907
2908 func BenchmarkServerPosts(b *testing.B) {
2909 disableGoroutineTracking(b)
2910 b.ReportAllocs()
2911
2912 const msg = "Hello, world"
2913 st := newServerTesterWithRealConn(b, func(w http.ResponseWriter, r *http.Request) {
2914
2915
2916
2917 if n, err := io.Copy(io.Discard, r.Body); n != 0 || err != nil {
2918 b.Errorf("Copy error; got %v, %v; want 0, nil", n, err)
2919 }
2920 io.WriteString(w, msg)
2921 })
2922 defer st.Close()
2923 st.greet()
2924
2925
2926 if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil {
2927 b.Fatal(err)
2928 }
2929
2930 for i := 0; i < b.N; i++ {
2931 id := 1 + uint32(i)*2
2932 st.writeHeaders(HeadersFrameParam{
2933 StreamID: id,
2934 BlockFragment: st.encodeHeader(":method", "POST"),
2935 EndStream: false,
2936 EndHeaders: true,
2937 })
2938 st.writeData(id, true, nil)
2939 st.wantFrameType(FrameHeaders)
2940 if df := readFrame[*DataFrame](b, st); !df.StreamEnded() {
2941 b.Fatalf("DATA didn't have END_STREAM; got %v", df)
2942 }
2943 }
2944 }
2945
2946
2947
2948
2949 func BenchmarkServerToClientStreamDefaultOptions(b *testing.B) {
2950 benchmarkServerToClientStream(b)
2951 }
2952
2953
2954
2955 func BenchmarkServerToClientStreamReuseFrames(b *testing.B) {
2956 benchmarkServerToClientStream(b, optFramerReuseFrames)
2957 }
2958
2959 func benchmarkServerToClientStream(b *testing.B, newServerOpts ...interface{}) {
2960 disableGoroutineTracking(b)
2961 b.ReportAllocs()
2962 const msgLen = 1
2963
2964 const windowSize = 1<<16 - 1
2965
2966
2967 nextMsg := func(i int) []byte {
2968 msg := make([]byte, msgLen)
2969 msg[0] = byte(i)
2970 if len(msg) != msgLen {
2971 panic("invalid test setup msg length")
2972 }
2973 return msg
2974 }
2975
2976 st := newServerTesterWithRealConn(b, func(w http.ResponseWriter, r *http.Request) {
2977
2978
2979
2980 if n, err := io.Copy(io.Discard, r.Body); n != 0 || err != nil {
2981 b.Errorf("Copy error; got %v, %v; want 0, nil", n, err)
2982 }
2983 for i := 0; i < b.N; i += 1 {
2984 w.Write(nextMsg(i))
2985 w.(http.Flusher).Flush()
2986 }
2987 }, newServerOpts...)
2988 defer st.Close()
2989 st.greet()
2990
2991 const id = uint32(1)
2992
2993 st.writeHeaders(HeadersFrameParam{
2994 StreamID: id,
2995 BlockFragment: st.encodeHeader(":method", "POST"),
2996 EndStream: false,
2997 EndHeaders: true,
2998 })
2999
3000 st.writeData(id, true, nil)
3001 st.wantHeaders(wantHeader{
3002 streamID: 1,
3003 endStream: false,
3004 })
3005
3006 var pendingWindowUpdate = uint32(0)
3007
3008 for i := 0; i < b.N; i += 1 {
3009 expected := nextMsg(i)
3010 st.wantData(wantData{
3011 streamID: 1,
3012 endStream: false,
3013 data: expected,
3014 })
3015
3016 pendingWindowUpdate += uint32(len(expected))
3017 if pendingWindowUpdate >= windowSize/2 {
3018 if err := st.fr.WriteWindowUpdate(0, pendingWindowUpdate); err != nil {
3019 b.Fatal(err)
3020 }
3021 if err := st.fr.WriteWindowUpdate(id, pendingWindowUpdate); err != nil {
3022 b.Fatal(err)
3023 }
3024 pendingWindowUpdate = 0
3025 }
3026 }
3027 st.wantData(wantData{
3028 streamID: 1,
3029 endStream: true,
3030 })
3031 }
3032
3033
3034
3035 func TestIssue53(t *testing.T) {
3036 const data = "PRI * HTTP/2.0\r\n\r\nSM" +
3037 "\r\n\r\n\x00\x00\x00\x01\ainfinfin\ad"
3038 s := &http.Server{
3039 ErrorLog: log.New(io.MultiWriter(stderrv(), twriter{t: t}), "", log.LstdFlags),
3040 Handler: http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
3041 w.Write([]byte("hello"))
3042 }),
3043 }
3044 s2 := &Server{
3045 MaxReadFrameSize: 1 << 16,
3046 PermitProhibitedCipherSuites: true,
3047 }
3048 c := &issue53Conn{[]byte(data), false, false}
3049 s2.ServeConn(c, &ServeConnOpts{BaseConfig: s})
3050 if !c.closed {
3051 t.Fatal("connection is not closed")
3052 }
3053 }
3054
// issue53Conn is an in-memory net.Conn used by TestIssue53. Reads are served
// from data until it runs out and then return io.EOF. Writes are discarded
// but recorded in written. Close only records that it was called.
type issue53Conn struct {
	data    []byte
	closed  bool
	written bool
}

// Read copies as much of the remaining data into b as fits.
func (c *issue53Conn) Read(b []byte) (int, error) {
	switch {
	case len(c.data) == 0:
		return 0, io.EOF
	default:
		n := copy(b, c.data)
		c.data = c.data[n:]
		return n, nil
	}
}

// Write reports all of b as written and marks the connection as written to.
func (c *issue53Conn) Write(b []byte) (int, error) {
	c.written = true
	return len(b), nil
}

// Close marks the connection as closed.
func (c *issue53Conn) Close() error {
	c.closed = true
	return nil
}

// LocalAddr returns a fixed loopback address.
func (c *issue53Conn) LocalAddr() net.Addr {
	return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 49706}
}

// RemoteAddr returns the same fixed loopback address as LocalAddr.
func (c *issue53Conn) RemoteAddr() net.Addr {
	return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 49706}
}

// The deadline setters are no-ops.
func (c *issue53Conn) SetDeadline(t time.Time) error      { return nil }
func (c *issue53Conn) SetReadDeadline(t time.Time) error  { return nil }
func (c *issue53Conn) SetWriteDeadline(t time.Time) error { return nil }
3089
3090
3091 func TestServeConnOptsNilReceiverBehavior(t *testing.T) {
3092 defer func() {
3093 if r := recover(); r != nil {
3094 t.Errorf("got a panic that should not happen: %v", r)
3095 }
3096 }()
3097
3098 var o *ServeConnOpts
3099 if o.context() == nil {
3100 t.Error("o.context should not return nil")
3101 }
3102 if o.baseConfig() == nil {
3103 t.Error("o.baseConfig should not return nil")
3104 }
3105 if o.handler() == nil {
3106 t.Error("o.handler should not return nil")
3107 }
3108 }
3109
3110
3111 func TestConfigureServer(t *testing.T) {
3112 tests := []struct {
3113 name string
3114 tlsConfig *tls.Config
3115 wantErr string
3116 }{
3117 {
3118 name: "empty server",
3119 },
3120 {
3121 name: "empty CipherSuites",
3122 tlsConfig: &tls.Config{},
3123 },
3124 {
3125 name: "bad CipherSuites but MinVersion TLS 1.3",
3126 tlsConfig: &tls.Config{
3127 MinVersion: tls.VersionTLS13,
3128 CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384},
3129 },
3130 },
3131 {
3132 name: "just the required cipher suite",
3133 tlsConfig: &tls.Config{
3134 CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
3135 },
3136 },
3137 {
3138 name: "just the alternative required cipher suite",
3139 tlsConfig: &tls.Config{
3140 CipherSuites: []uint16{tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256},
3141 },
3142 },
3143 {
3144 name: "missing required cipher suite",
3145 tlsConfig: &tls.Config{
3146 CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384},
3147 },
3148 wantErr: "is missing an HTTP/2-required",
3149 },
3150 {
3151 name: "required after bad",
3152 tlsConfig: &tls.Config{
3153 CipherSuites: []uint16{tls.TLS_RSA_WITH_RC4_128_SHA, tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
3154 },
3155 },
3156 {
3157 name: "bad after required",
3158 tlsConfig: &tls.Config{
3159 CipherSuites: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_RSA_WITH_RC4_128_SHA},
3160 },
3161 },
3162 }
3163 for _, tt := range tests {
3164 srv := &http.Server{TLSConfig: tt.tlsConfig}
3165 err := ConfigureServer(srv, nil)
3166 if (err != nil) != (tt.wantErr != "") {
3167 if tt.wantErr != "" {
3168 t.Errorf("%s: success, but want error", tt.name)
3169 } else {
3170 t.Errorf("%s: unexpected error: %v", tt.name, err)
3171 }
3172 }
3173 if err != nil && tt.wantErr != "" && !strings.Contains(err.Error(), tt.wantErr) {
3174 t.Errorf("%s: err = %v; want substring %q", tt.name, err, tt.wantErr)
3175 }
3176 if err == nil && !srv.TLSConfig.PreferServerCipherSuites {
3177 t.Errorf("%s: PreferServerCipherSuite is false; want true", tt.name)
3178 }
3179 }
3180 }
3181
3182 func TestServerNoAutoContentLengthOnHead(t *testing.T) {
3183 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3184
3185 })
3186 defer st.Close()
3187 st.greet()
3188 st.writeHeaders(HeadersFrameParam{
3189 StreamID: 1,
3190 BlockFragment: st.encodeHeader(":method", "HEAD"),
3191 EndStream: true,
3192 EndHeaders: true,
3193 })
3194 st.wantHeaders(wantHeader{
3195 streamID: 1,
3196 endStream: true,
3197 header: http.Header{
3198 ":status": []string{"200"},
3199 },
3200 })
3201 }
3202
3203
3204 func TestServerNoDuplicateContentType(t *testing.T) {
3205 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3206 w.Header()["Content-Type"] = []string{""}
3207 fmt.Fprintf(w, "<html><head></head><body>hi</body></html>")
3208 })
3209 defer st.Close()
3210 st.greet()
3211 st.writeHeaders(HeadersFrameParam{
3212 StreamID: 1,
3213 BlockFragment: st.encodeHeader(),
3214 EndStream: true,
3215 EndHeaders: true,
3216 })
3217 st.wantHeaders(wantHeader{
3218 streamID: 1,
3219 endStream: false,
3220 header: http.Header{
3221 ":status": []string{"200"},
3222 "content-type": []string{""},
3223 "content-length": []string{"41"},
3224 },
3225 })
3226 }
3227
3228 func TestServerContentLengthCanBeDisabled(t *testing.T) {
3229 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3230 w.Header()["Content-Length"] = nil
3231 fmt.Fprintf(w, "OK")
3232 })
3233 defer st.Close()
3234 st.greet()
3235 st.writeHeaders(HeadersFrameParam{
3236 StreamID: 1,
3237 BlockFragment: st.encodeHeader(),
3238 EndStream: true,
3239 EndHeaders: true,
3240 })
3241 st.wantHeaders(wantHeader{
3242 streamID: 1,
3243 endStream: false,
3244 header: http.Header{
3245 ":status": []string{"200"},
3246 "content-type": []string{"text/plain; charset=utf-8"},
3247 },
3248 })
3249 }
3250
3251 func disableGoroutineTracking(t testing.TB) {
3252 old := DebugGoroutines
3253 DebugGoroutines = false
3254 t.Cleanup(func() { DebugGoroutines = old })
3255 }
3256
3257 func BenchmarkServer_GetRequest(b *testing.B) {
3258 disableGoroutineTracking(b)
3259 b.ReportAllocs()
3260 const msg = "Hello, world."
3261 st := newServerTesterWithRealConn(b, func(w http.ResponseWriter, r *http.Request) {
3262 n, err := io.Copy(io.Discard, r.Body)
3263 if err != nil || n > 0 {
3264 b.Errorf("Read %d bytes, error %v; want 0 bytes.", n, err)
3265 }
3266 io.WriteString(w, msg)
3267 })
3268 defer st.Close()
3269
3270 st.greet()
3271
3272 if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil {
3273 b.Fatal(err)
3274 }
3275 hbf := st.encodeHeader(":method", "GET")
3276 for i := 0; i < b.N; i++ {
3277 streamID := uint32(1 + 2*i)
3278 st.writeHeaders(HeadersFrameParam{
3279 StreamID: streamID,
3280 BlockFragment: hbf,
3281 EndStream: true,
3282 EndHeaders: true,
3283 })
3284 st.wantFrameType(FrameHeaders)
3285 st.wantFrameType(FrameData)
3286 }
3287 }
3288
3289 func BenchmarkServer_PostRequest(b *testing.B) {
3290 disableGoroutineTracking(b)
3291 b.ReportAllocs()
3292 const msg = "Hello, world."
3293 st := newServerTesterWithRealConn(b, func(w http.ResponseWriter, r *http.Request) {
3294 n, err := io.Copy(io.Discard, r.Body)
3295 if err != nil || n > 0 {
3296 b.Errorf("Read %d bytes, error %v; want 0 bytes.", n, err)
3297 }
3298 io.WriteString(w, msg)
3299 })
3300 defer st.Close()
3301 st.greet()
3302
3303 if err := st.fr.WriteWindowUpdate(0, uint32(b.N*len(msg))); err != nil {
3304 b.Fatal(err)
3305 }
3306 hbf := st.encodeHeader(":method", "POST")
3307 for i := 0; i < b.N; i++ {
3308 streamID := uint32(1 + 2*i)
3309 st.writeHeaders(HeadersFrameParam{
3310 StreamID: streamID,
3311 BlockFragment: hbf,
3312 EndStream: false,
3313 EndHeaders: true,
3314 })
3315 st.writeData(streamID, true, nil)
3316 st.wantFrameType(FrameHeaders)
3317 st.wantFrameType(FrameData)
3318 }
3319 }
3320
// connStateConn wraps a net.Conn and adds a ConnectionState method that
// returns a fixed tls.ConnectionState. This lets a plain connection (such as
// a net.Pipe end) be served as if it had negotiated TLS.
type connStateConn struct {
	net.Conn
	cs tls.ConnectionState
}

// ConnectionState returns the stored TLS connection state.
func (c connStateConn) ConnectionState() tls.ConnectionState {
	state := c.cs
	return state
}
3327
3328
3329
3330 func TestServerHandleCustomConn(t *testing.T) {
3331 var s Server
3332 c1, c2 := net.Pipe()
3333 clientDone := make(chan struct{})
3334 handlerDone := make(chan struct{})
3335 var req *http.Request
3336 go func() {
3337 defer close(clientDone)
3338 defer c2.Close()
3339 fr := NewFramer(c2, c2)
3340 io.WriteString(c2, ClientPreface)
3341 fr.WriteSettings()
3342 fr.WriteSettingsAck()
3343 f, err := fr.ReadFrame()
3344 if err != nil {
3345 t.Error(err)
3346 return
3347 }
3348 if sf, ok := f.(*SettingsFrame); !ok || sf.IsAck() {
3349 t.Errorf("Got %v; want non-ACK SettingsFrame", summarizeFrame(f))
3350 return
3351 }
3352 f, err = fr.ReadFrame()
3353 if err != nil {
3354 t.Error(err)
3355 return
3356 }
3357 if sf, ok := f.(*SettingsFrame); !ok || !sf.IsAck() {
3358 t.Errorf("Got %v; want ACK SettingsFrame", summarizeFrame(f))
3359 return
3360 }
3361 var henc hpackEncoder
3362 fr.WriteHeaders(HeadersFrameParam{
3363 StreamID: 1,
3364 BlockFragment: henc.encodeHeaderRaw(t, ":method", "GET", ":path", "/", ":scheme", "https", ":authority", "foo.com"),
3365 EndStream: true,
3366 EndHeaders: true,
3367 })
3368 go io.Copy(io.Discard, c2)
3369 <-handlerDone
3370 }()
3371 const testString = "my custom ConnectionState"
3372 fakeConnState := tls.ConnectionState{
3373 ServerName: testString,
3374 Version: tls.VersionTLS12,
3375 CipherSuite: cipher_TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
3376 }
3377 go s.ServeConn(connStateConn{c1, fakeConnState}, &ServeConnOpts{
3378 BaseConfig: &http.Server{
3379 Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
3380 defer close(handlerDone)
3381 req = r
3382 }),
3383 }})
3384 <-clientDone
3385
3386 if req.TLS == nil {
3387 t.Fatalf("Request.TLS is nil. Got: %#v", req)
3388 }
3389 if req.TLS.ServerName != testString {
3390 t.Fatalf("Request.TLS = %+v; want ServerName of %q", req.TLS, testString)
3391 }
3392 }
3393
3394
3395 func TestServer_Rejects_ConnHeaders(t *testing.T) {
3396 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3397 t.Error("should not get to Handler")
3398 })
3399 defer st.Close()
3400 st.greet()
3401 st.bodylessReq1("connection", "foo")
3402 st.wantHeaders(wantHeader{
3403 streamID: 1,
3404 endStream: false,
3405 header: http.Header{
3406 ":status": []string{"400"},
3407 "content-type": []string{"text/plain; charset=utf-8"},
3408 "x-content-type-options": []string{"nosniff"},
3409 "content-length": []string{"51"},
3410 },
3411 })
3412 }
3413
3414 type hpackEncoder struct {
3415 enc *hpack.Encoder
3416 buf bytes.Buffer
3417 }
3418
3419 func (he *hpackEncoder) encodeHeaderRaw(t *testing.T, headers ...string) []byte {
3420 if len(headers)%2 == 1 {
3421 panic("odd number of kv args")
3422 }
3423 he.buf.Reset()
3424 if he.enc == nil {
3425 he.enc = hpack.NewEncoder(&he.buf)
3426 }
3427 for len(headers) > 0 {
3428 k, v := headers[0], headers[1]
3429 err := he.enc.WriteField(hpack.HeaderField{Name: k, Value: v})
3430 if err != nil {
3431 t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err)
3432 }
3433 headers = headers[2:]
3434 }
3435 return he.buf.Bytes()
3436 }
3437
3438 func TestCheckValidHTTP2Request(t *testing.T) {
3439 tests := []struct {
3440 h http.Header
3441 want error
3442 }{
3443 {
3444 h: http.Header{"Te": {"trailers"}},
3445 want: nil,
3446 },
3447 {
3448 h: http.Header{"Te": {"trailers", "bogus"}},
3449 want: errors.New(`request header "TE" may only be "trailers" in HTTP/2`),
3450 },
3451 {
3452 h: http.Header{"Foo": {""}},
3453 want: nil,
3454 },
3455 {
3456 h: http.Header{"Connection": {""}},
3457 want: errors.New(`request header "Connection" is not valid in HTTP/2`),
3458 },
3459 {
3460 h: http.Header{"Proxy-Connection": {""}},
3461 want: errors.New(`request header "Proxy-Connection" is not valid in HTTP/2`),
3462 },
3463 {
3464 h: http.Header{"Keep-Alive": {""}},
3465 want: errors.New(`request header "Keep-Alive" is not valid in HTTP/2`),
3466 },
3467 {
3468 h: http.Header{"Upgrade": {""}},
3469 want: errors.New(`request header "Upgrade" is not valid in HTTP/2`),
3470 },
3471 }
3472 for i, tt := range tests {
3473 got := checkValidHTTP2RequestHeaders(tt.h)
3474 if !equalError(got, tt.want) {
3475 t.Errorf("%d. checkValidHTTP2Request = %v; want %v", i, got, tt.want)
3476 }
3477 }
3478 }
3479
3480
3481 func TestExpect100ContinueAfterHandlerWrites(t *testing.T) {
3482 const msg = "Hello"
3483 const msg2 = "World"
3484
3485 doRead := make(chan bool, 1)
3486 defer close(doRead)
3487
3488 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
3489 io.WriteString(w, msg)
3490 w.(http.Flusher).Flush()
3491
3492
3493 <-doRead
3494 r.Body.Read(make([]byte, 10))
3495
3496 io.WriteString(w, msg2)
3497 })
3498
3499 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
3500 defer tr.CloseIdleConnections()
3501
3502 req, _ := http.NewRequest("POST", ts.URL, io.LimitReader(neverEnding('A'), 2<<20))
3503 req.Header.Set("Expect", "100-continue")
3504
3505 res, err := tr.RoundTrip(req)
3506 if err != nil {
3507 t.Fatal(err)
3508 }
3509 defer res.Body.Close()
3510
3511 buf := make([]byte, len(msg))
3512 if _, err := io.ReadFull(res.Body, buf); err != nil {
3513 t.Fatal(err)
3514 }
3515 if string(buf) != msg {
3516 t.Fatalf("msg = %q; want %q", buf, msg)
3517 }
3518
3519 doRead <- true
3520
3521 if _, err := io.ReadFull(res.Body, buf); err != nil {
3522 t.Fatal(err)
3523 }
3524 if string(buf) != msg2 {
3525 t.Fatalf("second msg = %q; want %q", buf, msg2)
3526 }
3527 }
3528
// funcReader adapts a plain function to the io.Reader interface.
type funcReader func([]byte) (n int, err error)

// Read calls f with p and returns its results unchanged.
func (f funcReader) Read(p []byte) (n int, err error) {
	n, err = f(p)
	return n, err
}
3532
3533
3534
3535 func TestUnreadFlowControlReturned_Server(t *testing.T) {
3536 for _, tt := range []struct {
3537 name string
3538 reqFn func(r *http.Request)
3539 }{
3540 {
3541 "body-open",
3542 func(r *http.Request) {},
3543 },
3544 {
3545 "body-closed",
3546 func(r *http.Request) {
3547 r.Body.Close()
3548 },
3549 },
3550 {
3551 "read-1-byte-and-close",
3552 func(r *http.Request) {
3553 b := make([]byte, 1)
3554 r.Body.Read(b)
3555 r.Body.Close()
3556 },
3557 },
3558 } {
3559 t.Run(tt.name, func(t *testing.T) {
3560 unblock := make(chan bool, 1)
3561 defer close(unblock)
3562
3563 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
3564
3565
3566
3567 tt.reqFn(r)
3568 <-unblock
3569 })
3570
3571 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
3572 defer tr.CloseIdleConnections()
3573
3574
3575 iters := 100
3576 if testing.Short() {
3577 iters = 20
3578 }
3579 for i := 0; i < iters; i++ {
3580 body := io.MultiReader(
3581 io.LimitReader(neverEnding('A'), 16<<10),
3582 funcReader(func([]byte) (n int, err error) {
3583 unblock <- true
3584 return 0, io.EOF
3585 }),
3586 )
3587 req, _ := http.NewRequest("POST", ts.URL, body)
3588 res, err := tr.RoundTrip(req)
3589 if err != nil {
3590 t.Fatal(tt.name, err)
3591 }
3592 res.Body.Close()
3593 }
3594 })
3595 }
3596 }
3597
3598 func TestServerReturnsStreamAndConnFlowControlOnBodyClose(t *testing.T) {
3599 unblockHandler := make(chan struct{})
3600 defer close(unblockHandler)
3601
3602 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3603 r.Body.Close()
3604 w.WriteHeader(200)
3605 w.(http.Flusher).Flush()
3606 <-unblockHandler
3607 })
3608 defer st.Close()
3609
3610 st.greet()
3611 st.writeHeaders(HeadersFrameParam{
3612 StreamID: 1,
3613 BlockFragment: st.encodeHeader(),
3614 EndHeaders: true,
3615 })
3616 st.wantHeaders(wantHeader{
3617 streamID: 1,
3618 endStream: false,
3619 })
3620 const size = inflowMinRefresh
3621 st.writeData(1, false, make([]byte, size))
3622 st.wantWindowUpdate(0, size)
3623 unblockHandler <- struct{}{}
3624 st.wantData(wantData{
3625 streamID: 1,
3626 endStream: true,
3627 })
3628 }
3629
3630 func TestServerIdleTimeout(t *testing.T) {
3631 if testing.Short() {
3632 t.Skip("skipping in short mode")
3633 }
3634
3635 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3636 }, func(h2s *Server) {
3637 h2s.IdleTimeout = 500 * time.Millisecond
3638 })
3639 defer st.Close()
3640
3641 st.greet()
3642 st.advance(500 * time.Millisecond)
3643 st.wantGoAway(0, ErrCodeNo)
3644 }
3645
3646 func TestServerIdleTimeout_AfterRequest(t *testing.T) {
3647 if testing.Short() {
3648 t.Skip("skipping in short mode")
3649 }
3650 const (
3651 requestTimeout = 2 * time.Second
3652 idleTimeout = 1 * time.Second
3653 )
3654
3655 var st *serverTester
3656 st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3657 st.group.Sleep(requestTimeout)
3658 }, func(h2s *Server) {
3659 h2s.IdleTimeout = idleTimeout
3660 })
3661 defer st.Close()
3662
3663 st.greet()
3664
3665
3666
3667 st.bodylessReq1()
3668 st.advance(requestTimeout)
3669 st.wantHeaders(wantHeader{
3670 streamID: 1,
3671 endStream: true,
3672 })
3673
3674
3675
3676 st.advance(idleTimeout)
3677 st.wantGoAway(1, ErrCodeNo)
3678 }
3679
3680
3681
3682
3683 func TestRequestBodyReadCloseRace(t *testing.T) {
3684 for i := 0; i < 100; i++ {
3685 body := &requestBody{
3686 pipe: &pipe{
3687 b: new(bytes.Buffer),
3688 },
3689 }
3690 body.pipe.CloseWithError(io.EOF)
3691
3692 done := make(chan bool, 1)
3693 buf := make([]byte, 10)
3694 go func() {
3695 time.Sleep(1 * time.Millisecond)
3696 body.Close()
3697 done <- true
3698 }()
3699 body.Read(buf)
3700 <-done
3701 }
3702 }
3703
3704 func TestIssue20704Race(t *testing.T) {
3705 if testing.Short() && os.Getenv("GO_BUILDER_NAME") == "" {
3706 t.Skip("skipping in short mode")
3707 }
3708 const (
3709 itemSize = 1 << 10
3710 itemCount = 100
3711 )
3712
3713 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
3714 for i := 0; i < itemCount; i++ {
3715 _, err := w.Write(make([]byte, itemSize))
3716 if err != nil {
3717 return
3718 }
3719 }
3720 })
3721
3722 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
3723 defer tr.CloseIdleConnections()
3724 cl := &http.Client{Transport: tr}
3725
3726 for i := 0; i < 1000; i++ {
3727 resp, err := cl.Get(ts.URL)
3728 if err != nil {
3729 t.Fatal(err)
3730 }
3731
3732
3733 resp.Body.Close()
3734 }
3735 }
3736
3737 func TestServer_Rejects_TooSmall(t *testing.T) {
3738 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
3739 io.ReadAll(r.Body)
3740 return nil
3741 }, func(st *serverTester) {
3742 st.writeHeaders(HeadersFrameParam{
3743 StreamID: 1,
3744 BlockFragment: st.encodeHeader(
3745 ":method", "POST",
3746 "content-length", "4",
3747 ),
3748 EndStream: false,
3749 EndHeaders: true,
3750 })
3751 st.writeData(1, true, []byte("12345"))
3752 st.wantRSTStream(1, ErrCodeProtocol)
3753 st.wantFlowControlConsumed(0, 0)
3754 })
3755 }
3756
3757
3758
3759 func TestServerHandlerConnectionClose(t *testing.T) {
3760 unblockHandler := make(chan bool, 1)
3761 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
3762 w.Header().Set("Connection", "close")
3763 w.Header().Set("Foo", "bar")
3764 w.(http.Flusher).Flush()
3765 <-unblockHandler
3766 return nil
3767 }, func(st *serverTester) {
3768 defer close(unblockHandler)
3769 st.writeHeaders(HeadersFrameParam{
3770 StreamID: 1,
3771 BlockFragment: st.encodeHeader(),
3772 EndStream: true,
3773 EndHeaders: true,
3774 })
3775 var sawGoAway bool
3776 var sawRes bool
3777 var sawWindowUpdate bool
3778 for {
3779 f := st.readFrame()
3780 if f == nil {
3781 break
3782 }
3783 switch f := f.(type) {
3784 case *GoAwayFrame:
3785 sawGoAway = true
3786 if f.LastStreamID != 1 || f.ErrCode != ErrCodeNo {
3787 t.Errorf("unexpected GOAWAY frame: %v", summarizeFrame(f))
3788 }
3789
3790
3791 st.writeHeaders(HeadersFrameParam{
3792 StreamID: 3,
3793 BlockFragment: st.encodeHeader(),
3794 EndStream: false,
3795 EndHeaders: true,
3796 })
3797 st.fr.WriteRSTStream(3, ErrCodeCancel)
3798
3799
3800
3801 st.writeHeaders(HeadersFrameParam{
3802 StreamID: 5,
3803 BlockFragment: st.encodeHeader(),
3804 EndStream: false,
3805 EndHeaders: true,
3806 })
3807
3808 st.writeData(5, true, make([]byte, 1<<19))
3809 case *HeadersFrame:
3810 goth := st.decodeHeader(f.HeaderBlockFragment())
3811 wanth := [][2]string{
3812 {":status", "200"},
3813 {"foo", "bar"},
3814 }
3815 if !reflect.DeepEqual(goth, wanth) {
3816 t.Errorf("got headers %v; want %v", goth, wanth)
3817 }
3818 sawRes = true
3819 case *DataFrame:
3820 if f.StreamID != 1 || !f.StreamEnded() || len(f.Data()) != 0 {
3821 t.Errorf("unexpected DATA frame: %v", summarizeFrame(f))
3822 }
3823 case *WindowUpdateFrame:
3824 if !sawGoAway {
3825 t.Errorf("unexpected WINDOW_UPDATE frame: %v", summarizeFrame(f))
3826 return
3827 }
3828 if f.StreamID != 0 {
3829 st.t.Fatalf("WindowUpdate StreamID = %d; want 5", f.FrameHeader.StreamID)
3830 return
3831 }
3832 sawWindowUpdate = true
3833 unblockHandler <- true
3834 st.sync()
3835 st.advance(goAwayTimeout)
3836 default:
3837 t.Logf("unexpected frame: %v", summarizeFrame(f))
3838 }
3839 }
3840 if !sawGoAway {
3841 t.Errorf("didn't see GOAWAY")
3842 }
3843 if !sawRes {
3844 t.Errorf("didn't see response")
3845 }
3846 if !sawWindowUpdate {
3847 t.Errorf("didn't see WINDOW_UPDATE")
3848 }
3849 })
3850 }
3851
3852 func TestServer_Headers_HalfCloseRemote(t *testing.T) {
3853 var st *serverTester
3854 writeData := make(chan bool)
3855 writeHeaders := make(chan bool)
3856 leaveHandler := make(chan bool)
3857 st = newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3858 if st.stream(1) == nil {
3859 t.Errorf("nil stream 1 in handler")
3860 }
3861 if got, want := st.streamState(1), stateOpen; got != want {
3862 t.Errorf("in handler, state is %v; want %v", got, want)
3863 }
3864 writeData <- true
3865 if n, err := r.Body.Read(make([]byte, 1)); n != 0 || err != io.EOF {
3866 t.Errorf("body read = %d, %v; want 0, EOF", n, err)
3867 }
3868 if got, want := st.streamState(1), stateHalfClosedRemote; got != want {
3869 t.Errorf("in handler, state is %v; want %v", got, want)
3870 }
3871 writeHeaders <- true
3872
3873 <-leaveHandler
3874 })
3875 st.greet()
3876
3877 st.writeHeaders(HeadersFrameParam{
3878 StreamID: 1,
3879 BlockFragment: st.encodeHeader(),
3880 EndStream: false,
3881 EndHeaders: true,
3882 })
3883 <-writeData
3884 st.writeData(1, true, nil)
3885
3886 <-writeHeaders
3887
3888 st.writeHeaders(HeadersFrameParam{
3889 StreamID: 1,
3890 BlockFragment: st.encodeHeader(),
3891 EndStream: false,
3892 EndHeaders: true,
3893 })
3894
3895 defer close(leaveHandler)
3896
3897 st.wantRSTStream(1, ErrCodeStreamClosed)
3898 }
3899
3900 func TestServerGracefulShutdown(t *testing.T) {
3901 handlerDone := make(chan struct{})
3902 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
3903 <-handlerDone
3904 w.Header().Set("x-foo", "bar")
3905 })
3906 defer st.Close()
3907
3908 st.greet()
3909 st.bodylessReq1()
3910
3911 st.sync()
3912 st.h1server.Shutdown(context.Background())
3913
3914 st.wantGoAway(1, ErrCodeNo)
3915
3916 close(handlerDone)
3917 st.sync()
3918
3919 st.wantHeaders(wantHeader{
3920 streamID: 1,
3921 endStream: true,
3922 header: http.Header{
3923 ":status": []string{"200"},
3924 "x-foo": []string{"bar"},
3925 "content-length": []string{"0"},
3926 },
3927 })
3928
3929 n, err := st.cc.Read([]byte{0})
3930 if n != 0 || err == nil {
3931 t.Errorf("Read = %v, %v; want 0, non-nil", n, err)
3932 }
3933 }
3934
3935
3936 func TestContentEncodingNoSniffing(t *testing.T) {
3937 type resp struct {
3938 name string
3939 body []byte
3940
3941
3942
3943 contentEncoding interface{}
3944 wantContentType string
3945 }
3946
3947 resps := []*resp{
3948 {
3949 name: "gzip content-encoding, gzipped",
3950 contentEncoding: "application/gzip",
3951 wantContentType: "",
3952 body: func() []byte {
3953 buf := new(bytes.Buffer)
3954 gzw := gzip.NewWriter(buf)
3955 gzw.Write([]byte("doctype html><p>Hello</p>"))
3956 gzw.Close()
3957 return buf.Bytes()
3958 }(),
3959 },
3960 {
3961 name: "zlib content-encoding, zlibbed",
3962 contentEncoding: "application/zlib",
3963 wantContentType: "",
3964 body: func() []byte {
3965 buf := new(bytes.Buffer)
3966 zw := zlib.NewWriter(buf)
3967 zw.Write([]byte("doctype html><p>Hello</p>"))
3968 zw.Close()
3969 return buf.Bytes()
3970 }(),
3971 },
3972 {
3973 name: "no content-encoding",
3974 wantContentType: "application/x-gzip",
3975 body: func() []byte {
3976 buf := new(bytes.Buffer)
3977 gzw := gzip.NewWriter(buf)
3978 gzw.Write([]byte("doctype html><p>Hello</p>"))
3979 gzw.Close()
3980 return buf.Bytes()
3981 }(),
3982 },
3983 {
3984 name: "phony content-encoding",
3985 contentEncoding: "foo/bar",
3986 body: []byte("doctype html><p>Hello</p>"),
3987 },
3988 {
3989 name: "empty but set content-encoding",
3990 contentEncoding: "",
3991 wantContentType: "audio/mpeg",
3992 body: []byte("ID3"),
3993 },
3994 }
3995
3996 for _, tt := range resps {
3997 t.Run(tt.name, func(t *testing.T) {
3998 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
3999 if tt.contentEncoding != nil {
4000 w.Header().Set("Content-Encoding", tt.contentEncoding.(string))
4001 }
4002 w.Write(tt.body)
4003 })
4004
4005 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
4006 defer tr.CloseIdleConnections()
4007
4008 req, _ := http.NewRequest("GET", ts.URL, nil)
4009 res, err := tr.RoundTrip(req)
4010 if err != nil {
4011 t.Fatalf("GET %s: %v", ts.URL, err)
4012 }
4013 defer res.Body.Close()
4014
4015 g := res.Header.Get("Content-Encoding")
4016 t.Logf("%s: Content-Encoding: %s", ts.URL, g)
4017
4018 if w := tt.contentEncoding; g != w {
4019 if w != nil {
4020 t.Errorf("Content-Encoding mismatch\n\tgot: %q\n\twant: %q", g, w)
4021 } else if g != "" {
4022 t.Errorf("Unexpected Content-Encoding %q", g)
4023 }
4024 }
4025
4026 g = res.Header.Get("Content-Type")
4027 if w := tt.wantContentType; g != w {
4028 t.Errorf("Content-Type mismatch\n\tgot: %q\n\twant: %q", g, w)
4029 }
4030 t.Logf("%s: Content-Type: %s", ts.URL, g)
4031 })
4032 }
4033 }
4034
4035 func TestServerWindowUpdateOnBodyClose(t *testing.T) {
4036 const windowSize = 65535 * 2
4037 content := make([]byte, windowSize)
4038 blockCh := make(chan bool)
4039 errc := make(chan error, 1)
4040 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
4041 buf := make([]byte, 4)
4042 n, err := io.ReadFull(r.Body, buf)
4043 if err != nil {
4044 errc <- err
4045 return
4046 }
4047 if n != len(buf) {
4048 errc <- fmt.Errorf("too few bytes read: %d", n)
4049 return
4050 }
4051 blockCh <- true
4052 <-blockCh
4053 errc <- nil
4054 }, func(s *Server) {
4055 s.MaxUploadBufferPerConnection = windowSize
4056 s.MaxUploadBufferPerStream = windowSize
4057 })
4058 defer st.Close()
4059
4060 st.greet()
4061 st.writeHeaders(HeadersFrameParam{
4062 StreamID: 1,
4063 BlockFragment: st.encodeHeader(
4064 ":method", "POST",
4065 "content-length", strconv.Itoa(len(content)),
4066 ),
4067 EndStream: false,
4068 EndHeaders: true,
4069 })
4070 st.writeData(1, false, content[:windowSize/2])
4071 <-blockCh
4072 st.stream(1).body.CloseWithError(io.EOF)
4073 blockCh <- true
4074
4075
4076 increments := windowSize / 2
4077 for {
4078 f := st.readFrame()
4079 if f == nil {
4080 break
4081 }
4082 if wu, ok := f.(*WindowUpdateFrame); ok && wu.StreamID == 0 {
4083 increments -= int(wu.Increment)
4084 if increments == 0 {
4085 break
4086 }
4087 }
4088 }
4089
4090
4091 st.writeData(1, false, content[windowSize/2:])
4092 st.wantWindowUpdate(0, windowSize/2)
4093
4094 if err := <-errc; err != nil {
4095 t.Error(err)
4096 }
4097 }
4098
4099 func TestNoErrorLoggedOnPostAfterGOAWAY(t *testing.T) {
4100 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {})
4101 defer st.Close()
4102
4103 st.greet()
4104
4105 content := "some content"
4106 st.writeHeaders(HeadersFrameParam{
4107 StreamID: 1,
4108 BlockFragment: st.encodeHeader(
4109 ":method", "POST",
4110 "content-length", strconv.Itoa(len(content)),
4111 ),
4112 EndStream: false,
4113 EndHeaders: true,
4114 })
4115 st.wantHeaders(wantHeader{
4116 streamID: 1,
4117 endStream: true,
4118 })
4119
4120 st.sc.startGracefulShutdown()
4121 st.wantRSTStream(1, ErrCodeNo)
4122 st.wantGoAway(1, ErrCodeNo)
4123
4124 st.writeData(1, true, []byte(content))
4125 st.Close()
4126
4127 if bytes.Contains(st.serverLogBuf.Bytes(), []byte("PROTOCOL_ERROR")) {
4128 t.Error("got protocol error")
4129 }
4130 }
4131
4132 func TestServerSendsProcessing(t *testing.T) {
4133 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
4134 w.WriteHeader(http.StatusProcessing)
4135 w.Write([]byte("stuff"))
4136
4137 return nil
4138 }, func(st *serverTester) {
4139 getSlash(st)
4140 st.wantHeaders(wantHeader{
4141 streamID: 1,
4142 endStream: false,
4143 header: http.Header{
4144 ":status": []string{"102"},
4145 },
4146 })
4147 st.wantHeaders(wantHeader{
4148 streamID: 1,
4149 endStream: false,
4150 header: http.Header{
4151 ":status": []string{"200"},
4152 "content-type": []string{"text/plain; charset=utf-8"},
4153 "content-length": []string{"5"},
4154 },
4155 })
4156 })
4157 }
4158
4159 func TestServerSendsEarlyHints(t *testing.T) {
4160 testServerResponse(t, func(w http.ResponseWriter, r *http.Request) error {
4161 h := w.Header()
4162 h.Add("Content-Length", "123")
4163 h.Add("Link", "</style.css>; rel=preload; as=style")
4164 h.Add("Link", "</script.js>; rel=preload; as=script")
4165 w.WriteHeader(http.StatusEarlyHints)
4166
4167 h.Add("Link", "</foo.js>; rel=preload; as=script")
4168 w.WriteHeader(http.StatusEarlyHints)
4169
4170 w.Write([]byte("stuff"))
4171
4172 return nil
4173 }, func(st *serverTester) {
4174 getSlash(st)
4175 st.wantHeaders(wantHeader{
4176 streamID: 1,
4177 endStream: false,
4178 header: http.Header{
4179 ":status": []string{"103"},
4180 "link": []string{
4181 "</style.css>; rel=preload; as=style",
4182 "</script.js>; rel=preload; as=script",
4183 },
4184 },
4185 })
4186 st.wantHeaders(wantHeader{
4187 streamID: 1,
4188 endStream: false,
4189 header: http.Header{
4190 ":status": []string{"103"},
4191 "link": []string{
4192 "</style.css>; rel=preload; as=style",
4193 "</script.js>; rel=preload; as=script",
4194 "</foo.js>; rel=preload; as=script",
4195 },
4196 },
4197 })
4198 st.wantHeaders(wantHeader{
4199 streamID: 1,
4200 endStream: false,
4201 header: http.Header{
4202 ":status": []string{"200"},
4203 "link": []string{
4204 "</style.css>; rel=preload; as=style",
4205 "</script.js>; rel=preload; as=script",
4206 "</foo.js>; rel=preload; as=script",
4207 },
4208 "content-type": []string{"text/plain; charset=utf-8"},
4209 "content-length": []string{"123"},
4210 },
4211 })
4212 })
4213 }
4214
4215 func TestProtocolErrorAfterGoAway(t *testing.T) {
4216 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
4217 io.Copy(io.Discard, r.Body)
4218 })
4219 defer st.Close()
4220
4221 st.greet()
4222 content := "some content"
4223 st.writeHeaders(HeadersFrameParam{
4224 StreamID: 1,
4225 BlockFragment: st.encodeHeader(
4226 ":method", "POST",
4227 "content-length", strconv.Itoa(len(content)),
4228 ),
4229 EndStream: false,
4230 EndHeaders: true,
4231 })
4232 st.writeData(1, false, []byte(content[:5]))
4233
4234
4235
4236 if err := st.fr.WriteGoAway(1, ErrCodeNo, nil); err != nil {
4237 t.Fatal(err)
4238 }
4239 if err := st.fr.WriteWindowUpdate(0, 1<<31-1); err != nil {
4240 t.Fatal(err)
4241 }
4242
4243 st.advance(goAwayTimeout)
4244 st.wantGoAway(1, ErrCodeNo)
4245 st.wantClosed()
4246 }
4247
4248 func TestServerInitialFlowControlWindow(t *testing.T) {
4249 for _, want := range []int32{
4250 65535,
4251 1 << 19,
4252 1 << 21,
4253
4254
4255
4256
4257
4258 65535 * 2,
4259 } {
4260 t.Run(fmt.Sprint(want), func(t *testing.T) {
4261
4262 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
4263 }, func(s *Server) {
4264 s.MaxUploadBufferPerConnection = want
4265 })
4266 st.writePreface()
4267 st.writeSettings()
4268 _ = readFrame[*SettingsFrame](t, st)
4269 st.writeSettingsAck()
4270 st.writeHeaders(HeadersFrameParam{
4271 StreamID: 1,
4272 BlockFragment: st.encodeHeader(),
4273 EndStream: true,
4274 EndHeaders: true,
4275 })
4276 window := 65535
4277 Frames:
4278 for {
4279 f := st.readFrame()
4280 switch f := f.(type) {
4281 case *WindowUpdateFrame:
4282 if f.FrameHeader.StreamID != 0 {
4283 t.Errorf("WindowUpdate StreamID = %d; want 0", f.FrameHeader.StreamID)
4284 return
4285 }
4286 window += int(f.Increment)
4287 case *HeadersFrame:
4288 break Frames
4289 case nil:
4290 break Frames
4291 default:
4292 }
4293 }
4294 if window != int(want) {
4295 t.Errorf("got initial flow control window = %v, want %v", window, want)
4296 }
4297 })
4298 }
4299 }
4300
4301
4302
4303 func TestCanonicalHeaderCacheGrowth(t *testing.T) {
4304 for _, size := range []int{1, (1 << 20) - 10} {
4305 base := strings.Repeat("X", size)
4306 sc := &serverConn{
4307 serveG: newGoroutineLock(),
4308 }
4309 count := 0
4310 added := 0
4311 for added < 10*maxCachedCanonicalHeadersKeysSize {
4312 h := fmt.Sprintf("%v-%v", base, count)
4313 c := sc.canonicalHeader(h)
4314 if len(h) != len(c) {
4315 t.Errorf("sc.canonicalHeader(%q) = %q, want same length", h, c)
4316 }
4317 count++
4318 added += len(h)
4319 }
4320 total := 0
4321 for k, v := range sc.canonHeader {
4322 total += len(k) + len(v) + 100
4323 }
4324 if total > maxCachedCanonicalHeadersKeysSize {
4325 t.Errorf("after adding %v ~%v-byte headers, canonHeader cache is ~%v bytes, want <%v", count, size, total, maxCachedCanonicalHeadersKeysSize)
4326 }
4327 }
4328 }
4329
4330
4331
4332
4333
4334
4335 func TestServerWriteDoesNotRetainBufferAfterReturn(t *testing.T) {
4336 donec := make(chan struct{})
4337 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
4338 defer close(donec)
4339 buf := make([]byte, 1<<20)
4340 var i byte
4341 for {
4342 i++
4343 _, err := w.Write(buf)
4344 for j := range buf {
4345 buf[j] = byte(i)
4346 }
4347 if err != nil {
4348 return
4349 }
4350 }
4351 })
4352
4353 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
4354 defer tr.CloseIdleConnections()
4355
4356 req, _ := http.NewRequest("GET", ts.URL, nil)
4357 res, err := tr.RoundTrip(req)
4358 if err != nil {
4359 t.Fatal(err)
4360 }
4361 res.Body.Close()
4362 <-donec
4363 }
4364
4365
4366
4367
4368
4369
4370 func TestServerWriteDoesNotRetainBufferAfterServerClose(t *testing.T) {
4371 donec := make(chan struct{}, 1)
4372 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
4373 donec <- struct{}{}
4374 defer close(donec)
4375 buf := make([]byte, 1<<20)
4376 var i byte
4377 for {
4378 i++
4379 _, err := w.Write(buf)
4380 for j := range buf {
4381 buf[j] = byte(i)
4382 }
4383 if err != nil {
4384 return
4385 }
4386 }
4387 })
4388
4389 tr := &Transport{TLSClientConfig: tlsConfigInsecure}
4390 defer tr.CloseIdleConnections()
4391
4392 req, _ := http.NewRequest("GET", ts.URL, nil)
4393 res, err := tr.RoundTrip(req)
4394 if err != nil {
4395 t.Fatal(err)
4396 }
4397 defer res.Body.Close()
4398 <-donec
4399 ts.Config.Close()
4400 <-donec
4401 }
4402
// TestServerMaxHandlerGoroutines checks two things about handler goroutines:
//   - the number running at once is capped at MaxConcurrentStreams, even
//     when the client resets streams whose handlers are still running;
//   - a client that keeps opening and resetting streams is eventually sent
//     GOAWAY(ENHANCE_YOUR_CALM).
func TestServerMaxHandlerGoroutines(t *testing.T) {
	const maxHandlers = 10
	// Each handler sends its own stop channel on handlerc when it starts.
	handlerc := make(chan chan bool)
	// Closing donec releases any handler still blocked when the test exits.
	donec := make(chan struct{})
	defer close(donec)
	st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
		stopc := make(chan bool, 1)
		select {
		case handlerc <- stopc:
		case <-donec:
		}
		// Block until told to stop. Receiving true panics with
		// http.ErrAbortHandler; false (or a closed stopc) returns normally.
		select {
		case shouldPanic := <-stopc:
			if shouldPanic {
				panic(http.ErrAbortHandler)
			}
		case <-donec:
		}
	}, func(s *Server) {
		s.MaxConcurrentStreams = maxHandlers
	})
	defer st.Close()

	st.greet()

	// Start maxHandlers handlers. Each stream is reset right after its
	// handler has started; the handler itself keeps running.
	var stops []chan bool
	streamID := uint32(1)
	for i := 0; i < maxHandlers; i++ {
		st.writeHeaders(HeadersFrameParam{
			StreamID:      streamID,
			BlockFragment: st.encodeHeader(),
			EndStream:     true,
			EndHeaders:    true,
		})
		stops = append(stops, <-handlerc)
		st.fr.WriteRSTStream(streamID, ErrCodeCancel)
		streamID += 2
	}

	// With maxHandlers handlers still running, open one more stream and
	// reset it immediately.
	st.writeHeaders(HeadersFrameParam{
		StreamID:      streamID,
		BlockFragment: st.encodeHeader(),
		EndStream:     true,
		EndHeaders:    true,
	})
	st.fr.WriteRSTStream(streamID, ErrCodeCancel)
	streamID += 2

	// Open two more streams and leave them open.
	for i := 0; i < 2; i++ {
		st.writeHeaders(HeadersFrameParam{
			StreamID:      streamID,
			BlockFragment: st.encodeHeader(),
			EndStream:     true,
			EndHeaders:    true,
		})
		streamID += 2
	}

	// No new handler should start (within 1ms) while maxHandlers are
	// still running.
	select {
	case <-handlerc:
		t.Errorf("handler unexpectedly started while maxHandlers are already running")
	case <-time.After(1 * time.Millisecond):
	}

	// Stop two handlers: one returns normally and one panics. Then wait for
	// two more handlers to start in their place.
	stops[0] <- false
	stops[1] <- true
	stops = stops[2:]
	stops = append(stops, <-handlerc)
	stops = append(stops, <-handlerc)

	// Keep opening and immediately resetting streams; the server is expected
	// to respond with GOAWAY(ENHANCE_YOUR_CALM).
	for i := 0; i < 5*maxHandlers; i++ {
		st.writeHeaders(HeadersFrameParam{
			StreamID:      streamID,
			BlockFragment: st.encodeHeader(),
			EndStream:     true,
			EndHeaders:    true,
		})
		st.fr.WriteRSTStream(streamID, ErrCodeCancel)
		streamID += 2
	}
	fr := readFrame[*GoAwayFrame](t, st)
	if fr.ErrCode != ErrCodeEnhanceYourCalm {
		t.Errorf("err code = %v; want %v", fr.ErrCode, ErrCodeEnhanceYourCalm)
	}

	// Release the remaining handlers; a receive from a closed stopc yields
	// false, so they return without panicking.
	for _, s := range stops {
		close(s)
	}
}
4502
4503 func TestServerContinuationFlood(t *testing.T) {
4504 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
4505 fmt.Println(r.Header)
4506 }, func(s *http.Server) {
4507 s.MaxHeaderBytes = 4096
4508 })
4509 defer st.Close()
4510
4511 st.greet()
4512
4513 st.writeHeaders(HeadersFrameParam{
4514 StreamID: 1,
4515 BlockFragment: st.encodeHeader(),
4516 EndStream: true,
4517 })
4518 for i := 0; i < 1000; i++ {
4519 st.fr.WriteContinuation(1, false, st.encodeHeaderRaw(
4520 fmt.Sprintf("x-%v", i), "1234567890",
4521 ))
4522 }
4523 st.fr.WriteContinuation(1, true, st.encodeHeaderRaw(
4524 "x-last-header", "1",
4525 ))
4526
4527 for {
4528 f := st.readFrame()
4529 if f == nil {
4530 break
4531 }
4532 switch f := f.(type) {
4533 case *HeadersFrame:
4534 t.Fatalf("received HEADERS frame; want GOAWAY and a closed connection")
4535 case *GoAwayFrame:
4536
4537
4538
4539 if got, want := f.LastStreamID, uint32(1); got != want {
4540 t.Errorf("received GOAWAY with LastStreamId %v, want %v", got, want)
4541 }
4542
4543 }
4544 }
4545
4546
4547
4548
4549
4550
4551
4552
4553 }
4554
4555 func TestServerContinuationAfterInvalidHeader(t *testing.T) {
4556 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
4557 fmt.Println(r.Header)
4558 })
4559 defer st.Close()
4560
4561 st.greet()
4562
4563 st.writeHeaders(HeadersFrameParam{
4564 StreamID: 1,
4565 BlockFragment: st.encodeHeader(),
4566 EndStream: true,
4567 })
4568 st.fr.WriteContinuation(1, false, st.encodeHeaderRaw(
4569 "x-invalid-header", "\x00",
4570 ))
4571 st.fr.WriteContinuation(1, true, st.encodeHeaderRaw(
4572 "x-valid-header", "1",
4573 ))
4574
4575 var sawGoAway bool
4576 for {
4577 f := st.readFrame()
4578 if f == nil {
4579 break
4580 }
4581 switch f.(type) {
4582 case *GoAwayFrame:
4583 sawGoAway = true
4584 case *HeadersFrame:
4585 t.Fatalf("received HEADERS frame; want GOAWAY")
4586 }
4587 }
4588 if !sawGoAway {
4589 t.Errorf("connection closed with no GOAWAY frame; want one")
4590 }
4591 }
4592
4593 func TestServerUpgradeRequestPrefaceFailure(t *testing.T) {
4594
4595 s2 := &Server{
4596
4597 IdleTimeout: 60 * time.Minute,
4598 }
4599 c1, c2 := net.Pipe()
4600 donec := make(chan struct{})
4601 go func() {
4602 defer close(donec)
4603 s2.ServeConn(c1, &ServeConnOpts{
4604 UpgradeRequest: httptest.NewRequest("GET", "/", nil),
4605 })
4606 }()
4607
4608
4609 c2.Close()
4610 <-donec
4611 }
4612
4613
4614 func TestServerRequestCancelOnError(t *testing.T) {
4615 recvc := make(chan struct{})
4616 donec := make(chan struct{})
4617 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
4618 close(recvc)
4619 <-r.Context().Done()
4620 close(donec)
4621 })
4622 defer st.Close()
4623
4624 st.greet()
4625
4626
4627 st.writeHeaders(HeadersFrameParam{
4628 StreamID: 1,
4629 BlockFragment: st.encodeHeader(),
4630 EndStream: true,
4631 EndHeaders: true,
4632 })
4633 <-recvc
4634
4635
4636
4637
4638 st.writeHeaders(HeadersFrameParam{
4639 StreamID: 1,
4640 BlockFragment: st.encodeHeader(),
4641 EndStream: true,
4642 EndHeaders: true,
4643 })
4644 <-donec
4645 }
4646
4647 func TestServerSetReadWriteDeadlineRace(t *testing.T) {
4648 ts := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
4649 ctl := http.NewResponseController(w)
4650 ctl.SetReadDeadline(time.Now().Add(3600 * time.Second))
4651 ctl.SetWriteDeadline(time.Now().Add(3600 * time.Second))
4652 })
4653 resp, err := ts.Client().Get(ts.URL)
4654 if err != nil {
4655 t.Fatal(err)
4656 }
4657 resp.Body.Close()
4658 }
4659
4660 func TestServerWriteByteTimeout(t *testing.T) {
4661 const timeout = 1 * time.Second
4662 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
4663 w.Write(make([]byte, 100))
4664 }, func(s *Server) {
4665 s.WriteByteTimeout = timeout
4666 })
4667 st.greet()
4668
4669 st.cc.(*synctestNetConn).SetReadBufferSize(1)
4670 st.writeHeaders(HeadersFrameParam{
4671 StreamID: 1,
4672 BlockFragment: st.encodeHeader(),
4673 EndStream: true,
4674 EndHeaders: true,
4675 })
4676
4677
4678 for i := 0; i < 10; i++ {
4679 st.advance(timeout - 1)
4680 if n, err := st.cc.Read(make([]byte, 1)); n != 1 || err != nil {
4681 t.Fatalf("read %v: %v, %v; want 1, nil", i, n, err)
4682 }
4683 }
4684
4685
4686
4687 st.advance(1 * time.Second)
4688 st.advance(1 * time.Second)
4689 st.wantClosed()
4690 }
4691
4692 func TestServerPingSent(t *testing.T) {
4693 const readIdleTimeout = 15 * time.Second
4694 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
4695 }, func(s *Server) {
4696 s.ReadIdleTimeout = readIdleTimeout
4697 })
4698 st.greet()
4699
4700 st.wantIdle()
4701
4702 st.advance(readIdleTimeout)
4703 _ = readFrame[*PingFrame](t, st)
4704 st.wantIdle()
4705
4706 st.advance(14 * time.Second)
4707 st.wantIdle()
4708 st.advance(1 * time.Second)
4709 st.wantClosed()
4710 }
4711
4712 func TestServerPingResponded(t *testing.T) {
4713 const readIdleTimeout = 15 * time.Second
4714 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
4715 }, func(s *Server) {
4716 s.ReadIdleTimeout = readIdleTimeout
4717 })
4718 st.greet()
4719
4720 st.wantIdle()
4721
4722 st.advance(readIdleTimeout)
4723 pf := readFrame[*PingFrame](t, st)
4724 st.wantIdle()
4725
4726 st.advance(14 * time.Second)
4727 st.wantIdle()
4728
4729 st.writePing(true, pf.Data)
4730
4731 st.advance(2 * time.Second)
4732 st.wantIdle()
4733 }
4734
View as plain text