1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26 package http2
27
28 import (
29 "bufio"
30 "bytes"
31 "context"
32 "crypto/rand"
33 "crypto/tls"
34 "errors"
35 "fmt"
36 "io"
37 "log"
38 "math"
39 "net"
40 "net/http"
41 "net/textproto"
42 "net/url"
43 "os"
44 "reflect"
45 "runtime"
46 "strconv"
47 "strings"
48 "sync"
49 "time"
50
51 "golang.org/x/net/http/httpguts"
52 "golang.org/x/net/http2/hpack"
53 )
54
const (
	// prefaceTimeout bounds how long readPreface waits for the client
	// connection preface before returning errPrefaceTimeout.
	prefaceTimeout = 10 * time.Second
	// firstSettingsTimeout is the delay of the settings timer armed in
	// serve; the timer is stopped once the first frame from the client has
	// been processed.
	firstSettingsTimeout = 2 * time.Second
	// handlerChunkWriteSize is the buffer size of the bufio.Writer created
	// for each pooled responseWriterState.
	handlerChunkWriteSize = 4 << 10
	// defaultMaxStreams is not referenced in this part of the file.
	// NOTE(review): presumably the fallback for Server.MaxConcurrentStreams;
	// confirm against configFromServer.
	defaultMaxStreams = 250

	// maxQueuedControlFrames is the number of queued control frames above
	// which the serve loop gives up and closes the connection.
	maxQueuedControlFrames = 10000
)
66
// Sentinel errors used by the server connection code.
var (
	// errClientDisconnected is reported to writers and open streams once
	// the connection is no longer being served.
	errClientDisconnected = errors.New("client disconnected")
	// errClosedBody is not used in this part of the file.
	errClosedBody = errors.New("body closed by handler")
	// errHandlerComplete is the close reason used by wroteFrame when a
	// stream-ending write completes on a half-closed (remote) stream.
	errHandlerComplete = errors.New("http2: request body closed due to handler exiting")
	// errStreamClosed is returned by writeDataFromHandler when the stream
	// was closed before a write result became available.
	errStreamClosed = errors.New("http2: stream closed")
)
73
74 var responseWriterStatePool = sync.Pool{
75 New: func() interface{} {
76 rws := &responseWriterState{}
77 rws.bw = bufio.NewWriterSize(chunkWriter{rws}, handlerChunkWriteSize)
78 return rws
79 },
80 }
81
82
// Test hooks. All are nil unless set (presumably by tests in this package).
var (
	// testHookOnConn, if set, is called for every connection handed to the
	// protocol handler installed by ConfigureServer.
	testHookOnConn func()
	// testHookGetServerConn, if set, receives each serverConn from
	// serveConn just before any upgrade request is handled and serve runs.
	testHookGetServerConn func(*serverConn)
	// testHookOnPanicMu, if set, is held by notePanic while it consults
	// testHookOnPanic.
	testHookOnPanicMu *sync.Mutex
	// testHookOnPanic, if set, is given the value recovered from a panic in
	// serve; notePanic re-panics with that value when it returns true.
	testHookOnPanic func(sc *serverConn, panicVal interface{}) (rePanic bool)
)
89
90
// Server holds the HTTP/2 server configuration. It is attached to an
// *http.Server with ConfigureServer, or used directly through ServeConn.
//
// serveConn does not read most tunables from this struct directly: it reads
// them from the per-connection config returned by configFromServer (defined
// elsewhere), so defaults and clamping are decided there. The field comments
// below describe how the corresponding config value is used in this file.
type Server struct {
	// MaxHandlers is not referenced in this part of the file.
	// NOTE(review): presumably a cap on concurrently running handlers;
	// confirm where it is enforced.
	MaxHandlers int

	// MaxConcurrentStreams: the config value of the same name becomes the
	// connection's advMaxStreams and is advertised to the client in the
	// initial SETTINGS frame as SettingMaxConcurrentStreams.
	MaxConcurrentStreams uint32

	// MaxDecoderHeaderTableSize: the config value sizes the HPACK decoder
	// used for incoming headers and is advertised as
	// SettingHeaderTableSize.
	MaxDecoderHeaderTableSize uint32

	// MaxEncoderHeaderTableSize: the config value is applied as the
	// dynamic table size limit of the HPACK encoder used for outgoing
	// headers.
	MaxEncoderHeaderTableSize uint32

	// MaxReadFrameSize: the config value is set as the framer's maximum
	// read frame size and advertised as SettingMaxFrameSize.
	MaxReadFrameSize uint32

	// PermitProhibitedCipherSuites: when the config value is false,
	// serveConn rejects TLS connections whose negotiated cipher suite is
	// flagged by isBadCipher.
	PermitProhibitedCipherSuites bool

	// IdleTimeout, when positive, arms a timer in serve; when it fires the
	// connection is sent GOAWAY with ErrCodeNo. If zero, ConfigureServer
	// fills it from the http.Server's IdleTimeout or, failing that, its
	// ReadTimeout.
	IdleTimeout time.Duration

	// ReadIdleTimeout is not read directly in this part of the file.
	// NOTE(review): presumably surfaces as the config's SendPingTimeout,
	// which arms the read-idle (PING) timer in serve; confirm.
	ReadIdleTimeout time.Duration

	// PingTimeout: the config value is the delay handlePingTimer waits
	// after sending a PING; if the timer fires while the ping is still
	// marked as sent, the connection is closed.
	PingTimeout time.Duration

	// WriteByteTimeout: the config value is passed to newBufferedWriter
	// for the connection's buffered writer.
	WriteByteTimeout time.Duration

	// MaxUploadBufferPerConnection: when the config value exceeds
	// initialWindowSize, serve sends a connection-level WINDOW_UPDATE for
	// the difference.
	MaxUploadBufferPerConnection int32

	// MaxUploadBufferPerStream: the config value becomes the initial
	// per-stream receive window and is advertised as
	// SettingInitialWindowSize.
	MaxUploadBufferPerStream int32

	// NewWriteScheduler, if non-nil, builds the write scheduler for each
	// connection. If nil, newRoundRobinWriteScheduler is used.
	NewWriteScheduler func() WriteScheduler

	// CountError: the config value, if non-nil, is installed on the framer
	// and stored on the connection as its error-counting callback.
	CountError func(errType string)

	// state tracks active connections for graceful shutdown. It is set by
	// ConfigureServer and may be nil otherwise; its methods tolerate a nil
	// receiver.
	state *serverInternalState

	// group, when non-nil, replaces the real clock, timers and goroutine
	// bookkeeping (see now, newTimer, afterFunc and markNewGoroutine).
	group synctestGroupInterface
}
183
184 func (s *Server) markNewGoroutine() {
185 if s.group != nil {
186 s.group.Join()
187 }
188 }
189
190 func (s *Server) now() time.Time {
191 if s.group != nil {
192 return s.group.Now()
193 }
194 return time.Now()
195 }
196
197
198 func (s *Server) newTimer(d time.Duration) timer {
199 if s.group != nil {
200 return s.group.NewTimer(d)
201 }
202 return timeTimer{time.NewTimer(d)}
203 }
204
205
206 func (s *Server) afterFunc(d time.Duration, f func()) timer {
207 if s.group != nil {
208 return s.group.AfterFunc(d, f)
209 }
210 return timeTimer{time.AfterFunc(d, f)}
211 }
212
// serverInternalState is the mutable state shared by all connections of a
// Server. Its methods accept a nil receiver and do nothing in that case.
type serverInternalState struct {
	// mu guards activeConns.
	mu sync.Mutex
	// activeConns is the set of connections currently inside serveConn.
	activeConns map[*serverConn]struct{}
}
217
218 func (s *serverInternalState) registerConn(sc *serverConn) {
219 if s == nil {
220 return
221 }
222 s.mu.Lock()
223 s.activeConns[sc] = struct{}{}
224 s.mu.Unlock()
225 }
226
227 func (s *serverInternalState) unregisterConn(sc *serverConn) {
228 if s == nil {
229 return
230 }
231 s.mu.Lock()
232 delete(s.activeConns, sc)
233 s.mu.Unlock()
234 }
235
236 func (s *serverInternalState) startGracefulShutdown() {
237 if s == nil {
238 return
239 }
240 s.mu.Lock()
241 for sc := range s.activeConns {
242 sc.startGracefulShutdown()
243 }
244 s.mu.Unlock()
245 }
246
247
248
249
250
251
// ConfigureServer adds HTTP/2 support to the net/http server s, using conf as
// the HTTP/2 configuration. A nil conf is replaced by a zero Server.
//
// It panics if s is nil. It returns an error when s.TLSConfig has an explicit
// cipher suite list, allows TLS versions below 1.3, and the list contains
// neither of the two ECDHE AES_128_GCM_SHA256 suites checked below.
//
// ConfigureServer mutates both arguments: it replaces conf.state, may fill
// conf.IdleTimeout, registers a shutdown callback on s, and edits
// s.TLSConfig and s.TLSNextProto.
func ConfigureServer(s *http.Server, conf *Server) error {
	if s == nil {
		panic("nil *http.Server")
	}
	if conf == nil {
		conf = new(Server)
	}
	conf.state = &serverInternalState{activeConns: make(map[*serverConn]struct{})}
	// Inherit an idle timeout from the HTTP/1 server when none is set for
	// HTTP/2: prefer its IdleTimeout and fall back to its ReadTimeout.
	if h1, h2 := s, conf; h2.IdleTimeout == 0 {
		if h1.IdleTimeout != 0 {
			h2.IdleTimeout = h1.IdleTimeout
		} else {
			h2.IdleTimeout = h1.ReadTimeout
		}
	}
	// Have http.Server.Shutdown start a graceful shutdown on every active
	// HTTP/2 connection.
	s.RegisterOnShutdown(conf.state.startGracefulShutdown)

	if s.TLSConfig == nil {
		s.TLSConfig = new(tls.Config)
	} else if s.TLSConfig.CipherSuites != nil && s.TLSConfig.MinVersion < tls.VersionTLS13 {
		// The caller restricted the cipher suites and still permits a TLS
		// version below 1.3: require at least one of the suites below.
		haveRequired := false
		for _, cs := range s.TLSConfig.CipherSuites {
			switch cs {
			case tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
				tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
				haveRequired = true
			}
		}
		if !haveRequired {
			return fmt.Errorf("http2: TLSConfig.CipherSuites is missing an HTTP/2-required AES_128_GCM_SHA256 cipher (need at least one of TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 or TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256)")
		}
	}

	// Forced on regardless of the caller's setting.
	s.TLSConfig.PreferServerCipherSuites = true

	// Advertise HTTP/2 and HTTP/1.1 via NextProtos, appending each only if
	// it is not already listed.
	if !strSliceContains(s.TLSConfig.NextProtos, NextProtoTLS) {
		s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, NextProtoTLS)
	}
	if !strSliceContains(s.TLSConfig.NextProtos, "http/1.1") {
		s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "http/1.1")
	}

	if s.TLSNextProto == nil {
		s.TLSNextProto = map[string]func(*http.Server, *tls.Conn, http.Handler){}
	}
	// protoHandler serves one connection with conf. It is shared by both
	// TLSNextProto entries installed below.
	protoHandler := func(hs *http.Server, c net.Conn, h http.Handler, sawClientPreface bool) {
		if testHookOnConn != nil {
			testHookOnConn()
		}
		// If the handler exposes a BaseContext method, use its result as
		// the connection's base context; otherwise ctx stays nil and
		// ServeConnOpts.context falls back to context.Background.
		var ctx context.Context
		type baseContexter interface {
			BaseContext() context.Context
		}
		if bc, ok := h.(baseContexter); ok {
			ctx = bc.BaseContext()
		}
		conf.ServeConn(c, &ServeConnOpts{
			Context:          ctx,
			Handler:          h,
			BaseConfig:       hs,
			SawClientPreface: sawClientPreface,
		})
	}
	// Regular HTTP/2 over TLS: the client preface has not been read yet.
	s.TLSNextProto[NextProtoTLS] = func(hs *http.Server, c *tls.Conn, h http.Handler) {
		protoHandler(hs, c, h, false)
	}

	// Unencrypted HTTP/2: the underlying net.Conn is extracted from c and
	// served with sawClientPreface set, so readPreface is skipped.
	s.TLSNextProto[nextProtoUnencryptedHTTP2] = func(hs *http.Server, c *tls.Conn, h http.Handler) {
		nc, err := unencryptedNetConnFromTLSConn(c)
		if err != nil {
			// Log to the server's ErrorLog when set, else the standard
			// logger, then close the connection from a new goroutine.
			if lg := hs.ErrorLog; lg != nil {
				lg.Print(err)
			} else {
				log.Print(err)
			}
			go c.Close()
			return
		}
		protoHandler(hs, nc, h, true)
	}
	return nil
}
353
354
// ServeConnOpts are the options for Server.ServeConn.
type ServeConnOpts struct {
	// Context is the base context of the connection. If nil,
	// context.Background is used.
	Context context.Context

	// BaseConfig optionally supplies the http.Server whose settings the
	// connection uses. If nil, a zero http.Server is used.
	BaseConfig *http.Server

	// Handler serves the requests. If nil, BaseConfig.Handler is used, and
	// if that is also nil, http.DefaultServeMux.
	Handler http.Handler

	// UpgradeRequest, if non-nil, is passed to the connection's
	// upgradeRequest method before serving starts. serveConn resets the
	// field to nil afterwards.
	UpgradeRequest *http.Request

	// Settings, if non-nil, is the payload of a SETTINGS frame whose
	// settings are applied to the connection before serving starts; an
	// invalid payload causes the connection to be rejected. serveConn
	// resets the field to nil afterwards.
	Settings []byte

	// SawClientPreface reports that the client preface has already been
	// consumed from the connection, so readPreface does not read it again.
	SawClientPreface bool
}
383
384 func (o *ServeConnOpts) context() context.Context {
385 if o != nil && o.Context != nil {
386 return o.Context
387 }
388 return context.Background()
389 }
390
391 func (o *ServeConnOpts) baseConfig() *http.Server {
392 if o != nil && o.BaseConfig != nil {
393 return o.BaseConfig
394 }
395 return new(http.Server)
396 }
397
398 func (o *ServeConnOpts) handler() http.Handler {
399 if o != nil {
400 if o.Handler != nil {
401 return o.Handler
402 }
403 if o.BaseConfig != nil && o.BaseConfig.Handler != nil {
404 return o.BaseConfig.Handler
405 }
406 }
407 return http.DefaultServeMux
408 }
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424 func (s *Server) ServeConn(c net.Conn, opts *ServeConnOpts) {
425 s.serveConn(c, opts, nil)
426 }
427
// serveConn builds the serverConn for c, performs the pre-flight checks
// (TLS version and cipher, optional pre-supplied settings and upgrade
// request) and then runs the serve loop until the connection ends.
//
// newf, if non-nil, is called with the new serverConn before it is
// registered with the server state. opts is dereferenced directly here and
// must not be nil.
func (s *Server) serveConn(c net.Conn, opts *ServeConnOpts, newf func(*serverConn)) {
	baseCtx, cancel := serverConnBaseContext(c, opts)
	defer cancel()

	http1srv := opts.baseConfig()
	conf := configFromServer(http1srv, s)
	sc := &serverConn{
		srv:                         s,
		hs:                          http1srv,
		conn:                        c,
		baseCtx:                     baseCtx,
		remoteAddrStr:               c.RemoteAddr().String(),
		bw:                          newBufferedWriter(s.group, c, conf.WriteByteTimeout),
		handler:                     opts.handler(),
		streams:                     make(map[uint32]*stream),
		readFrameCh:                 make(chan readFrameResult),
		wantWriteFrameCh:            make(chan FrameWriteRequest, 8),
		serveMsgCh:                  make(chan interface{}, 8),
		wroteFrameCh:                make(chan frameWriteResult, 1), // buffered: writeFrameAsync sends exactly once
		bodyReadCh:                  make(chan bodyReadMsg),
		doneServing:                 make(chan struct{}),
		clientMaxStreams:            math.MaxUint32, // no limit until the client says otherwise
		advMaxStreams:               conf.MaxConcurrentStreams,
		initialStreamSendWindowSize: initialWindowSize,
		initialStreamRecvWindowSize: conf.MaxUploadBufferPerStream,
		maxFrameSize:                initialMaxFrameSize,
		pingTimeout:                 conf.PingTimeout,
		countErrorFunc:              conf.CountError,
		serveG:                      newGoroutineLock(),
		pushEnabled:                 true,
		sawClientPreface:            opts.SawClientPreface,
	}
	if newf != nil {
		newf(sc)
	}

	// Track the connection for graceful shutdown for as long as it is
	// being served. Both calls are no-ops when s.state is nil.
	s.state.registerConn(sc)
	defer s.state.unregisterConn(sc)

	// Clear any write deadline already set on the connection when the base
	// server has a WriteTimeout. NOTE(review): presumably net/http set it
	// before handing the connection over; confirm.
	if sc.hs.WriteTimeout > 0 {
		sc.conn.SetWriteDeadline(time.Time{})
	}

	if s.NewWriteScheduler != nil {
		sc.writeSched = s.NewWriteScheduler()
	} else {
		sc.writeSched = newRoundRobinWriteScheduler()
	}

	// Both connection-level flow-control windows start at
	// initialWindowSize; serve sends a WINDOW_UPDATE later if a larger
	// upload buffer is configured.
	sc.flow.add(initialWindowSize)
	sc.inflow.init(initialWindowSize)
	sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
	sc.hpackEncoder.SetMaxDynamicTableSizeLimit(conf.MaxEncoderHeaderTableSize)

	// The framer writes through the buffered writer and reads straight
	// from the connection.
	fr := NewFramer(sc.bw, c)
	if conf.CountError != nil {
		fr.countError = conf.CountError
	}
	fr.ReadMetaHeaders = hpack.NewDecoder(conf.MaxDecoderHeaderTableSize, nil)
	fr.MaxHeaderListSize = sc.maxHeaderListSize()
	fr.SetMaxReadFrameSize(conf.MaxReadFrameSize)
	sc.framer = fr

	if tc, ok := c.(connectionStater); ok {
		// Keep a private copy of the TLS state for the connection.
		sc.tlsState = new(tls.ConnectionState)
		*sc.tlsState = tc.ConnectionState()

		// Refuse anything older than TLS 1.2.
		if sc.tlsState.Version < tls.VersionTLS12 {
			sc.rejectConn(ErrCodeInadequateSecurity, "TLS version too low")
			return
		}

		if sc.tlsState.ServerName == "" {
			// Intentionally empty: a missing server name (SNI) is not
			// acted upon.
		}

		// Refuse cipher suites flagged by isBadCipher unless the
		// configuration explicitly permits them.
		if !conf.PermitProhibitedCipherSuites && isBadCipher(sc.tlsState.CipherSuite) {
			sc.rejectConn(ErrCodeInadequateSecurity, fmt.Sprintf("Prohibited TLS 1.2 Cipher Suite: %x", sc.tlsState.CipherSuite))
			return
		}
	}

	// Apply settings supplied up front by wrapping the raw payload in a
	// SettingsFrame and running each setting through processSetting.
	if opts.Settings != nil {
		fr := &SettingsFrame{
			FrameHeader: FrameHeader{valid: true},
			p:           opts.Settings,
		}
		if err := fr.ForeachSetting(sc.processSetting); err != nil {
			sc.rejectConn(ErrCodeProtocol, "invalid settings")
			return
		}
		opts.Settings = nil
	}

	if hook := testHookGetServerConn; hook != nil {
		hook(sc)
	}

	if opts.UpgradeRequest != nil {
		sc.upgradeRequest(opts.UpgradeRequest)
		opts.UpgradeRequest = nil
	}

	sc.serve(conf)
}
568
569 func serverConnBaseContext(c net.Conn, opts *ServeConnOpts) (ctx context.Context, cancel func()) {
570 ctx, cancel = context.WithCancel(opts.context())
571 ctx = context.WithValue(ctx, http.LocalAddrContextKey, c.LocalAddr())
572 if hs := opts.baseConfig(); hs != nil {
573 ctx = context.WithValue(ctx, http.ServerContextKey, hs)
574 }
575 return
576 }
577
578 func (sc *serverConn) rejectConn(err ErrCode, debug string) {
579 sc.vlogf("http2: server rejecting conn: %v, %s", err, debug)
580
581 sc.framer.WriteGoAway(0, err, []byte(debug))
582 sc.bw.Flush()
583 sc.conn.Close()
584 }
585
// serverConn is the state of one HTTP/2 connection being served.
type serverConn struct {
	// Set up in serveConn before the serve loop starts.
	srv              *Server
	hs               *http.Server
	conn             net.Conn
	bw               *bufferedWriter // buffered writer in front of conn; the framer writes to it
	handler          http.Handler
	baseCtx          context.Context
	framer           *Framer
	doneServing      chan struct{}          // closed when serve returns
	readFrameCh      chan readFrameResult   // readFrames goroutine -> serve
	wantWriteFrameCh chan FrameWriteRequest // handlers -> serve (writeFrameFromHandler)
	wroteFrameCh     chan frameWriteResult  // writeFrameAsync -> serve
	bodyReadCh       chan bodyReadMsg       // received in serve and passed to noteBodyRead
	serveMsgCh       chan interface{}       // timers and other goroutines -> serve (sendServeMsg)
	flow             outflow                // connection-level send window
	inflow           inflow                 // connection-level receive window
	tlsState         *tls.ConnectionState   // copy of the TLS state; nil if conn does not expose one
	remoteAddrStr    string
	writeSched       WriteScheduler
	countErrorFunc   func(errType string)

	// Connection state. The methods in this file that touch these fields
	// generally begin with sc.serveG.check().
	// NOTE(review): presumably owned exclusively by the serve goroutine;
	// confirm.
	serveG                      goroutineLock
	pushEnabled                 bool
	sawClientPreface            bool // preface already consumed; readPreface returns immediately
	sawFirstSettings            bool // set once the first frame is confirmed to be SETTINGS
	needToSendSettingsAck       bool
	unackedSettings             int // incremented after the initial SETTINGS frame is queued
	queuedControlFrames         int // control frames currently in the write scheduler
	clientMaxStreams            uint32
	advMaxStreams               uint32 // value advertised as SettingMaxConcurrentStreams
	curClientStreams            uint32
	curPushedStreams            uint32
	curHandlers                 uint32
	maxClientStreamID           uint32 // odd IDs at or below this with no entry in streams are closed
	maxPushPromiseID            uint32 // even IDs at or below this with no entry in streams are closed
	streams                     map[uint32]*stream
	unstartedHandlers           []unstartedHandler
	initialStreamSendWindowSize int32
	initialStreamRecvWindowSize int32
	maxFrameSize                int32
	peerMaxHeaderListSize       uint32
	canonHeader                 map[string]string // lazily allocated cache used by canonicalHeader
	canonHeaderKeysSize         int               // estimated size of canonHeader, capped by maxCachedCanonicalHeadersKeysSize
	writingFrame                bool              // a frame write has started and wroteFrame has not run yet
	writingFrameAsync           bool              // the in-flight write runs in a writeFrameAsync goroutine
	needsFrameFlush             bool              // something was written since the last flush
	inGoAway                    bool              // goAway has been called
	inFrameScheduleLoop         bool              // scheduleFrameWrite is running; guards against re-entry
	needToSendGoAway            bool              // a GOAWAY frame still has to be written
	pingSent                    bool              // handlePingTimer sent a PING
	sentPingData                [8]byte           // payload of that PING
	goAwayCode                  ErrCode
	shutdownTimer               timer // nil until shutDownIn is called
	idleTimer                   timer // nil unless the server's IdleTimeout is positive
	readIdleTimeout             time.Duration
	pingTimeout                 time.Duration
	readIdleTimer               timer // nil unless the config's SendPingTimeout is positive

	// Header encoding state: hpackEncoder writes into headerWriteBuf.
	headerWriteBuf bytes.Buffer
	hpackEncoder   *hpack.Encoder

	// shutdownOnce ensures startGracefulShutdown posts its message once.
	shutdownOnce sync.Once
}
653
654 func (sc *serverConn) maxHeaderListSize() uint32 {
655 n := sc.hs.MaxHeaderBytes
656 if n <= 0 {
657 n = http.DefaultMaxHeaderBytes
658 }
659 return uint32(adjustHTTP1MaxHeaderSize(int64(n)))
660 }
661
662 func (sc *serverConn) curOpenStreams() uint32 {
663 sc.serveG.check()
664 return sc.curClientStreams + sc.curPushedStreams
665 }
666
667
668
669
670
671
672
673
// stream is the server's state for a single HTTP/2 stream on a serverConn.
// Only part of its use is visible in this section of the file; comments on
// fields not used here are hedged accordingly.
type stream struct {
	// NOTE(review): presumably fixed at stream creation; confirm.
	sc        *serverConn
	id        uint32
	body      *pipe       // NOTE(review): presumably the request body pipe; confirm
	cw        closeWaiter // writeDataFromHandler waits on it to learn the stream was closed
	ctx       context.Context
	cancelCtx func()

	// Mutable per-stream state.
	bodyBytes        int64
	declBodyBytes    int64
	flow             outflow // stream-level send window
	inflow           inflow  // stream-level receive window
	state            streamState
	resetQueued      bool // set by resetStream once an RST_STREAM has been queued
	gotTrailerHeader bool
	wroteHeaders     bool // set by writeFrame when response headers are queued
	readDeadline     timer
	writeDeadline    timer
	closeErr         error

	trailer    http.Header
	reqTrailer http.Header
}
699
700 func (sc *serverConn) Framer() *Framer { return sc.framer }
701 func (sc *serverConn) CloseConn() error { return sc.conn.Close() }
702 func (sc *serverConn) Flush() error { return sc.bw.Flush() }
703 func (sc *serverConn) HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) {
704 return sc.hpackEncoder, &sc.headerWriteBuf
705 }
706
707 func (sc *serverConn) state(streamID uint32) (streamState, *stream) {
708 sc.serveG.check()
709
710 if st, ok := sc.streams[streamID]; ok {
711 return st.state, st
712 }
713
714
715
716
717
718
719 if streamID%2 == 1 {
720 if streamID <= sc.maxClientStreamID {
721 return stateClosed, nil
722 }
723 } else {
724 if streamID <= sc.maxPushPromiseID {
725 return stateClosed, nil
726 }
727 }
728 return stateIdle, nil
729 }
730
731
732
733
734 func (sc *serverConn) setConnState(state http.ConnState) {
735 if sc.hs.ConnState != nil {
736 sc.hs.ConnState(sc.conn, state)
737 }
738 }
739
740 func (sc *serverConn) vlogf(format string, args ...interface{}) {
741 if VerboseLogs {
742 sc.logf(format, args...)
743 }
744 }
745
746 func (sc *serverConn) logf(format string, args ...interface{}) {
747 if lg := sc.hs.ErrorLog; lg != nil {
748 lg.Printf(format, args...)
749 } else {
750 log.Printf(format, args...)
751 }
752 }
753
754
755
756
757
// errno returns v's underlying numeric value when v's dynamic type has kind
// uintptr (as syscall.Errno does), and 0 otherwise, including for a nil v.
// Using reflection avoids naming a platform-specific error type.
func errno(v error) uintptr {
	rv := reflect.ValueOf(v)
	if rv.Kind() != reflect.Uintptr {
		return 0
	}
	return uintptr(rv.Uint())
}
764
765
766
767 func isClosedConnError(err error) bool {
768 if err == nil {
769 return false
770 }
771
772 if errors.Is(err, net.ErrClosed) {
773 return true
774 }
775
776
777
778
779
780 if runtime.GOOS == "windows" {
781 if oe, ok := err.(*net.OpError); ok && oe.Op == "read" {
782 if se, ok := oe.Err.(*os.SyscallError); ok && se.Syscall == "wsarecv" {
783 const WSAECONNABORTED = 10053
784 const WSAECONNRESET = 10054
785 if n := errno(se.Err); n == WSAECONNRESET || n == WSAECONNABORTED {
786 return true
787 }
788 }
789 }
790 }
791 return false
792 }
793
794 func (sc *serverConn) condlogf(err error, format string, args ...interface{}) {
795 if err == nil {
796 return
797 }
798 if err == io.EOF || err == io.ErrUnexpectedEOF || isClosedConnError(err) || err == errPrefaceTimeout {
799
800 sc.vlogf(format, args...)
801 } else {
802 sc.logf(format, args...)
803 }
804 }
805
806
807
808
809
810
// maxCachedCanonicalHeadersKeysSize caps the estimated size of a connection's
// canonical header cache. canonicalHeader charges 100 + 2*len(key) per entry
// and stops adding entries that would push the total past this value.
const maxCachedCanonicalHeadersKeysSize = 2048
812
813 func (sc *serverConn) canonicalHeader(v string) string {
814 sc.serveG.check()
815 buildCommonHeaderMapsOnce()
816 cv, ok := commonCanonHeader[v]
817 if ok {
818 return cv
819 }
820 cv, ok = sc.canonHeader[v]
821 if ok {
822 return cv
823 }
824 if sc.canonHeader == nil {
825 sc.canonHeader = make(map[string]string)
826 }
827 cv = http.CanonicalHeaderKey(v)
828 size := 100 + len(v)*2
829 if sc.canonHeaderKeysSize+size <= maxCachedCanonicalHeadersKeysSize {
830 sc.canonHeader[v] = cv
831 sc.canonHeaderKeysSize += size
832 }
833 return cv
834 }
835
// readFrameResult is what the readFrames goroutine sends to the serve loop
// for every call to Framer.ReadFrame.
type readFrameResult struct {
	f   Frame // the frame read; see err
	err error // the error returned by ReadFrame, if any

	// readMore must be called by the serve loop once it is done with f.
	// readFrames does not read the next frame until it has been called (or
	// the connection is done serving).
	readMore func()
}
845
846
847
848
849
850 func (sc *serverConn) readFrames() {
851 sc.srv.markNewGoroutine()
852 gate := make(chan struct{})
853 gateDone := func() { gate <- struct{}{} }
854 for {
855 f, err := sc.framer.ReadFrame()
856 select {
857 case sc.readFrameCh <- readFrameResult{f, err, gateDone}:
858 case <-sc.doneServing:
859 return
860 }
861 select {
862 case <-gate:
863 case <-sc.doneServing:
864 return
865 }
866 if terminalReadFrameError(err) {
867 return
868 }
869 }
870 }
871
872
// frameWriteResult is the outcome of writing one frame. writeFrameAsync
// sends it on wroteFrameCh, and startFrameWrite builds one directly for
// writes done on the serve goroutine.
type frameWriteResult struct {
	_   incomparable
	wr  FrameWriteRequest // the request that was written
	err error             // the result of the write
}
878
879
880
881
882
883 func (sc *serverConn) writeFrameAsync(wr FrameWriteRequest, wd *writeData) {
884 sc.srv.markNewGoroutine()
885 var err error
886 if wd == nil {
887 err = wr.write.writeFrame(sc)
888 } else {
889 err = sc.framer.endWrite()
890 }
891 sc.wroteFrameCh <- frameWriteResult{wr: wr, err: err}
892 }
893
894 func (sc *serverConn) closeAllStreamsOnConnClose() {
895 sc.serveG.check()
896 for _, st := range sc.streams {
897 sc.closeStream(st, errClientDisconnected)
898 }
899 }
900
901 func (sc *serverConn) stopShutdownTimer() {
902 sc.serveG.check()
903 if t := sc.shutdownTimer; t != nil {
904 t.Stop()
905 }
906 }
907
// notePanic is deferred by serve. It does nothing unless testHookOnPanic is
// set; in that case it recovers a panic from serve, passes the value to the
// hook, and re-panics with the same value if the hook returns true.
//
// recover is called directly in this function, which is itself the deferred
// call, so it must not be moved into a helper.
func (sc *serverConn) notePanic() {
	// Hold the hook mutex, when one is installed, while consulting the hook.
	if testHookOnPanicMu != nil {
		testHookOnPanicMu.Lock()
		defer testHookOnPanicMu.Unlock()
	}
	if testHookOnPanic != nil {
		if e := recover(); e != nil {
			if testHookOnPanic(sc, e) {
				panic(e)
			}
		}
	}
}
922
// serve is the connection's main loop. It queues the initial SETTINGS frame,
// reads the client preface, starts the frame reader and then multiplexes
// write requests, write completions, incoming frames, body-read
// notifications and control messages until the connection is done.
//
// On return the deferred calls run in reverse order: doneServing is closed
// first, then the shutdown timer is stopped, all streams are closed, the
// connection is closed, and finally notePanic runs.
func (sc *serverConn) serve(conf http2Config) {
	sc.serveG.check()
	defer sc.notePanic()
	defer sc.conn.Close()
	defer sc.closeAllStreamsOnConnClose()
	defer sc.stopShutdownTimer()
	defer close(sc.doneServing) // unblocks goroutines selecting on doneServing

	if VerboseLogs {
		sc.vlogf("http2: server connection from %v on %p", sc.conn.RemoteAddr(), sc.hs)
	}

	// Queue the initial SETTINGS frame advertising this server's limits.
	settings := writeSettings{
		{SettingMaxFrameSize, conf.MaxReadFrameSize},
		{SettingMaxConcurrentStreams, sc.advMaxStreams},
		{SettingMaxHeaderListSize, sc.maxHeaderListSize()},
		{SettingHeaderTableSize, conf.MaxDecoderHeaderTableSize},
		{SettingInitialWindowSize, uint32(sc.initialStreamRecvWindowSize)},
	}
	if !disableExtendedConnectProtocol {
		settings = append(settings, Setting{SettingEnableConnectProtocol, 1})
	}
	sc.writeFrame(FrameWriteRequest{
		write: settings,
	})
	sc.unackedSettings++

	// The connection-level receive window starts at initialWindowSize
	// (see serveConn). Grant the difference when a larger buffer is
	// configured.
	if diff := conf.MaxUploadBufferPerConnection - initialWindowSize; diff > 0 {
		sc.sendWindowUpdate(nil, int(diff))
	}

	if err := sc.readPreface(); err != nil {
		sc.condlogf(err, "http2: server: error reading preface from client %v: %v", sc.conn.RemoteAddr(), err)
		return
	}

	// With the preface in hand, report the connection as active and then
	// idle to the base server's ConnState hook.
	sc.setConnState(http.StateActive)
	sc.setConnState(http.StateIdle)

	if sc.srv.IdleTimeout > 0 {
		sc.idleTimer = sc.srv.afterFunc(sc.srv.IdleTimeout, sc.onIdleTimer)
		defer sc.idleTimer.Stop()
	}

	if conf.SendPingTimeout > 0 {
		sc.readIdleTimeout = conf.SendPingTimeout
		sc.readIdleTimer = sc.srv.afterFunc(conf.SendPingTimeout, sc.onReadIdleTimer)
		defer sc.readIdleTimer.Stop()
	}

	// The reader goroutine exits once doneServing is closed or it hits a
	// terminal read error.
	go sc.readFrames()

	// The deferred Stop is bound to this timer value, so clearing the
	// settingsTimer variable later does not affect it.
	settingsTimer := sc.srv.afterFunc(firstSettingsTimeout, sc.onSettingsTimer)
	defer settingsTimer.Stop()

	lastFrameTime := sc.srv.now()
	loopNum := 0
	for {
		loopNum++
		select {
		case wr := <-sc.wantWriteFrameCh:
			// A StreamError sent as a write request is a request to
			// reset the stream.
			if se, ok := wr.write.(StreamError); ok {
				sc.resetStream(se)
				break // leaves the select only; the checks below still run
			}
			sc.writeFrame(wr)
		case res := <-sc.wroteFrameCh:
			sc.wroteFrame(res)
		case res := <-sc.readFrameCh:
			lastFrameTime = sc.srv.now()
			// If an asynchronous write has already completed, account
			// for it before processing the newly read frame.
			if sc.writingFrameAsync {
				select {
				case wroteRes := <-sc.wroteFrameCh:
					sc.wroteFrame(wroteRes)
				default:
				}
			}
			if !sc.processFrameFromReader(res) {
				return
			}
			// Allow readFrames to read the next frame.
			res.readMore()
			// A frame has been processed: the settings timer is no
			// longer needed.
			if settingsTimer != nil {
				settingsTimer.Stop()
				settingsTimer = nil
			}
		case m := <-sc.bodyReadCh:
			sc.noteBodyRead(m.st, m.n)
		case msg := <-sc.serveMsgCh:
			switch v := msg.(type) {
			case func(int):
				// Called with the current loop iteration number.
				v(loopNum)
			case *serverMessage:
				switch v {
				case settingsTimerMsg:
					sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr())
					return
				case idleTimerMsg:
					sc.vlogf("connection is idle")
					sc.goAway(ErrCodeNo)
				case readIdleTimerMsg:
					sc.handlePingTimer(lastFrameTime)
				case shutdownTimerMsg:
					sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr())
					return
				case gracefulShutdownMsg:
					sc.startGracefulShutdownInternal()
				case handlerDoneMsg:
					sc.handlerDone()
				default:
					panic("unknown timer")
				}
			case *startPushRequest:
				sc.startPush(v)
			case func(*serverConn):
				v(sc)
			default:
				panic(fmt.Sprintf("unexpected type %T", v))
			}
		}

		// Give up on the connection when too many control frames are
		// waiting in the send queue.
		if sc.queuedControlFrames > maxQueuedControlFrames {
			sc.vlogf("http2: too many control frames in send queue, closing connection")
			return
		}

		// Once the GOAWAY frame has been written, arm the shutdown timer
		// (once). This happens immediately for an error GOAWAY, and for a
		// graceful one (ErrCodeNo) only after every open stream finished.
		sentGoAway := sc.inGoAway && !sc.needToSendGoAway && !sc.writingFrame
		gracefulShutdownComplete := sc.goAwayCode == ErrCodeNo && sc.curOpenStreams() == 0
		if sentGoAway && sc.shutdownTimer == nil && (sc.goAwayCode != ErrCodeNo || gracefulShutdownComplete) {
			sc.shutDownIn(goAwayTimeout)
		}
	}
}
1068
1069 func (sc *serverConn) handlePingTimer(lastFrameReadTime time.Time) {
1070 if sc.pingSent {
1071 sc.vlogf("timeout waiting for PING response")
1072 sc.conn.Close()
1073 return
1074 }
1075
1076 pingAt := lastFrameReadTime.Add(sc.readIdleTimeout)
1077 now := sc.srv.now()
1078 if pingAt.After(now) {
1079
1080
1081 sc.readIdleTimer.Reset(pingAt.Sub(now))
1082 return
1083 }
1084
1085 sc.pingSent = true
1086
1087
1088 _, _ = rand.Read(sc.sentPingData[:])
1089 sc.writeFrame(FrameWriteRequest{
1090 write: &writePing{data: sc.sentPingData},
1091 })
1092 sc.readIdleTimer.Reset(sc.pingTimeout)
1093 }
1094
// serverMessage is the type of the control messages sent to the serve loop on
// serveMsgCh. Messages are identified by pointer, not by value.
type serverMessage int

// Message sentinels handled by the serve loop. Each is a distinct pointer.
var (
	settingsTimerMsg    = new(serverMessage) // the first-SETTINGS timer fired
	idleTimerMsg        = new(serverMessage) // the idle timer fired
	readIdleTimerMsg    = new(serverMessage) // the read-idle (PING) timer fired
	shutdownTimerMsg    = new(serverMessage) // the post-GOAWAY shutdown timer fired
	gracefulShutdownMsg = new(serverMessage) // a graceful shutdown was requested
	handlerDoneMsg      = new(serverMessage) // the serve loop calls handlerDone in response
)
1106
1107 func (sc *serverConn) onSettingsTimer() { sc.sendServeMsg(settingsTimerMsg) }
1108 func (sc *serverConn) onIdleTimer() { sc.sendServeMsg(idleTimerMsg) }
1109 func (sc *serverConn) onReadIdleTimer() { sc.sendServeMsg(readIdleTimerMsg) }
1110 func (sc *serverConn) onShutdownTimer() { sc.sendServeMsg(shutdownTimerMsg) }
1111
// sendServeMsg delivers msg to the serve loop. It must not be called from the
// serve goroutine itself. If the connection stops being served before the
// message can be delivered, the message is dropped.
func (sc *serverConn) sendServeMsg(msg interface{}) {
	sc.serveG.checkNotOn()
	select {
	case sc.serveMsgCh <- msg:
	case <-sc.doneServing:
	}
}
1119
// errPrefaceTimeout is returned by readPreface when the client preface does
// not arrive within prefaceTimeout.
var errPrefaceTimeout = errors.New("timeout waiting for client preface")
1121
1122
1123
1124
1125 func (sc *serverConn) readPreface() error {
1126 if sc.sawClientPreface {
1127 return nil
1128 }
1129 errc := make(chan error, 1)
1130 go func() {
1131
1132 buf := make([]byte, len(ClientPreface))
1133 if _, err := io.ReadFull(sc.conn, buf); err != nil {
1134 errc <- err
1135 } else if !bytes.Equal(buf, clientPreface) {
1136 errc <- fmt.Errorf("bogus greeting %q", buf)
1137 } else {
1138 errc <- nil
1139 }
1140 }()
1141 timer := sc.srv.newTimer(prefaceTimeout)
1142 defer timer.Stop()
1143 select {
1144 case <-timer.C():
1145 return errPrefaceTimeout
1146 case err := <-errc:
1147 if err == nil {
1148 if VerboseLogs {
1149 sc.vlogf("http2: server: client %v said hello", sc.conn.RemoteAddr())
1150 }
1151 }
1152 return err
1153 }
1154 }
1155
// errChanPool recycles the buffered (capacity 1) error channels that
// writeDataFromHandler uses to wait for a write result.
var errChanPool = sync.Pool{
	New: func() interface{} { return make(chan error, 1) },
}
1159
// writeDataPool recycles the writeData values that writeDataFromHandler uses
// as frame write arguments.
var writeDataPool = sync.Pool{
	New: func() interface{} { return new(writeData) },
}
1163
1164
1165
// writeDataFromHandler queues a DATA frame for stream and waits for the
// result of writing it. It is called from handler goroutines, not from the
// serve goroutine.
//
// It returns the write error, errClientDisconnected if the connection stops
// being served, or errStreamClosed if the stream is closed before a write
// result is available.
//
// The pooled channel and writeData are returned to their pools only on the
// paths where that is known to be safe: every early return below leaves both
// out of the pools, and the writeData goes back only when the frame write is
// known to have finished.
func (sc *serverConn) writeDataFromHandler(stream *stream, data []byte, endStream bool) error {
	ch := errChanPool.Get().(chan error)
	writeArg := writeDataPool.Get().(*writeData)
	*writeArg = writeData{stream.id, data, endStream}
	err := sc.writeFrameFromHandler(FrameWriteRequest{
		write:  writeArg,
		stream: stream,
		done:   ch,
	})
	if err != nil {
		return err
	}
	var frameWriteDone bool // the frame write finished, successfully or not
	select {
	case err = <-ch:
		frameWriteDone = true
	case <-sc.doneServing:
		return errClientDisconnected
	case <-stream.cw:
		// The stream was closed. A write result may be ready at the same
		// time; if so, prefer it over reporting the stream as closed.
		select {
		case err = <-ch:
			frameWriteDone = true
		default:
			return errStreamClosed
		}
	}
	errChanPool.Put(ch)
	if frameWriteDone {
		writeDataPool.Put(writeArg)
	}
	return err
}
1205
1206
1207
1208
1209
1210
1211
1212
// writeFrameFromHandler hands wr to the serve loop for writing. It must not
// be called from the serve goroutine.
//
// A nil return only means the request was accepted by the serve loop, not
// that the frame has been written. errClientDisconnected is returned when the
// connection stops being served before the request could be handed over.
func (sc *serverConn) writeFrameFromHandler(wr FrameWriteRequest) error {
	sc.serveG.checkNotOn()
	select {
	case sc.wantWriteFrameCh <- wr:
		return nil
	case <-sc.doneServing:
		// The serve loop is gone, so the request will never be picked up.
		return errClientDisconnected
	}
}
1224
1225
1226
1227
1228
1229
1230
1231
1232
// writeFrame queues wr on the write scheduler and kicks the scheduler. It
// must be called from the serve goroutine.
//
// Some requests are silently dropped instead of queued: writes other than a
// stream reset for a stream that is already closed, and 100-continue headers
// for a stream whose response headers were already queued.
func (sc *serverConn) writeFrame(wr FrameWriteRequest) {
	sc.serveG.check()

	// If true, wr is not pushed onto the scheduler.
	var ignoreWrite bool

	// Drop writes for closed streams, except RST_STREAM (a StreamError),
	// which is still allowed through.
	if wr.StreamID() != 0 {
		_, isReset := wr.write.(StreamError)
		if state, _ := sc.state(wr.StreamID()); state == stateClosed && !isReset {
			ignoreWrite = true
		}
	}

	// Record that response headers were queued, and suppress a
	// 100-continue that would come after them.
	switch wr.write.(type) {
	case *writeResHeaders:
		wr.stream.wroteHeaders = true
	case write100ContinueHeadersFrame:
		if wr.stream.wroteHeaders {
			// Nobody may be waiting on this frame: a dropped request
			// never gets a reply on wr.done.
			if wr.done != nil {
				panic("wr.done != nil for write100ContinueHeadersFrame")
			}
			ignoreWrite = true
		}
	}

	if !ignoreWrite {
		if wr.isControl() {
			sc.queuedControlFrames++
			// A negative count means the counter wrapped around; close
			// the connection.
			if sc.queuedControlFrames < 0 {
				sc.conn.Close()
			}
		}
		sc.writeSched.Push(wr)
	}
	sc.scheduleFrameWrite()
}
1293
1294
1295
1296
// startFrameWrite begins writing wr. Only one frame may be in flight at a
// time; starting a second one panics.
//
// Writes that fit in the buffered writer's free space are done synchronously
// on the serve goroutine and completed immediately through wroteFrame. Larger
// writes are handed to a writeFrameAsync goroutine, which reports back on
// wroteFrameCh.
func (sc *serverConn) startFrameWrite(wr FrameWriteRequest) {
	sc.serveG.check()
	if sc.writingFrame {
		panic("internal error: can only be writing one frame at a time")
	}

	st := wr.stream
	if st != nil {
		switch st.state {
		case stateHalfClosedLocal:
			switch wr.write.(type) {
			case StreamError, handlerPanicRST, writeWindowUpdate:
				// Stream resets and window updates are the only writes
				// still allowed on a half-closed (local) stream.
			default:
				panic(fmt.Sprintf("internal error: attempt to send frame on a half-closed-local stream: %v", wr))
			}
		case stateClosed:
			panic(fmt.Sprintf("internal error: attempt to send frame on a closed stream: %v", wr))
		}
	}
	// A push promise gets its promised stream ID allocated just before it
	// is written. On failure the writer is told and nothing is written.
	if wpp, ok := wr.write.(*writePushPromise); ok {
		var err error
		wpp.promisedID, err = wpp.allocatePromisedID()
		if err != nil {
			sc.writingFrameAsync = false
			wr.replyToWriter(err)
			return
		}
	}

	sc.writingFrame = true
	sc.needsFrameFlush = true
	if wr.write.staysWithinBuffer(sc.bw.Available()) {
		sc.writingFrameAsync = false
		err := wr.write.writeFrame(sc)
		sc.wroteFrame(frameWriteResult{wr: wr, err: err})
	} else if wd, ok := wr.write.(*writeData); ok {
		// Start the DATA frame here, on the serve goroutine, so that the
		// payload wd.p is handed to the framer before the asynchronous
		// goroutine starts; that goroutine only calls endWrite.
		sc.framer.startWriteDataPadded(wd.streamID, wd.endStream, wd.p, nil)
		sc.writingFrameAsync = true
		go sc.writeFrameAsync(wr, wd)
	} else {
		sc.writingFrameAsync = true
		go sc.writeFrameAsync(wr, nil)
	}
}
1346
1347
1348
1349
// errHandlerPanicked is the close reason wroteFrame uses for a stream once
// its handlerPanicRST frame has been written.
var errHandlerPanicked = errors.New("http2: handler panicked")
1351
1352
1353
// wroteFrame is called on the serve goroutine when the write started by
// startFrameWrite has finished. It clears the in-flight flags, updates stream
// state for writes that end or reset a stream, replies to the writer, and
// schedules the next write.
func (sc *serverConn) wroteFrame(res frameWriteResult) {
	sc.serveG.check()
	if !sc.writingFrame {
		panic("internal error: expected to be already writing a frame")
	}
	sc.writingFrame = false
	sc.writingFrameAsync = false

	// A failed write closes the connection.
	if res.err != nil {
		sc.conn.Close()
	}

	wr := res.wr

	if writeEndsStream(wr.write) {
		st := wr.stream
		if st == nil {
			panic("internal error: expecting non-nil stream")
		}
		switch st.state {
		case stateOpen:
			// The response ended while the stream was still open. Mark
			// the stream half-closed (local) and queue an RST_STREAM
			// with ErrCodeNo for it.
			st.state = stateHalfClosedLocal
			sc.resetStream(streamError(st.id, ErrCodeNo))
		case stateHalfClosedRemote:
			// Both directions are now finished.
			sc.closeStream(st, errHandlerComplete)
		}
	} else {
		switch v := wr.write.(type) {
		case StreamError:
			// The reset may refer to a stream that is not (or no longer)
			// tracked; only close it if it is.
			if st, ok := sc.streams[v.StreamID]; ok {
				sc.closeStream(st, v)
			}
		case handlerPanicRST:
			sc.closeStream(wr.stream, errHandlerPanicked)
		}
	}

	// Let the writer, if it asked for a reply, know the result.
	wr.replyToWriter(res.err)

	sc.scheduleFrameWrite()
}
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
// scheduleFrameWrite starts as many frame writes as it can without waiting.
// It returns immediately if a write is in flight or if it is already running
// further up the call stack (startFrameWrite can re-enter it via wroteFrame).
//
// Each iteration starts at most one write, chosen in this order: a pending
// GOAWAY, a pending SETTINGS ack, the next frame from the write scheduler
// (skipped once a GOAWAY with an error code is in progress), and finally a
// flush if anything was written since the last one. The loop ends when an
// asynchronous write is started or nothing is left to do.
func (sc *serverConn) scheduleFrameWrite() {
	sc.serveG.check()
	if sc.writingFrame || sc.inFrameScheduleLoop {
		return
	}
	sc.inFrameScheduleLoop = true
	for !sc.writingFrameAsync {
		if sc.needToSendGoAway {
			sc.needToSendGoAway = false
			sc.startFrameWrite(FrameWriteRequest{
				write: &writeGoAway{
					maxStreamID: sc.maxClientStreamID,
					code:        sc.goAwayCode,
				},
			})
			continue
		}
		if sc.needToSendSettingsAck {
			sc.needToSendSettingsAck = false
			sc.startFrameWrite(FrameWriteRequest{write: writeSettingsAck{}})
			continue
		}
		if !sc.inGoAway || sc.goAwayCode == ErrCodeNo {
			if wr, ok := sc.writeSched.Pop(); ok {
				if wr.isControl() {
					sc.queuedControlFrames--
				}
				sc.startFrameWrite(wr)
				continue
			}
		}
		if sc.needsFrameFlush {
			sc.startFrameWrite(FrameWriteRequest{write: flushFrameWriter{}})
			sc.needsFrameFlush = false // cleared afterwards because startFrameWrite sets it
			continue
		}
		break
	}
	sc.inFrameScheduleLoop = false
}
1462
1463
1464
1465
1466
1467
1468
1469
1470 func (sc *serverConn) startGracefulShutdown() {
1471 sc.serveG.checkNotOn()
1472 sc.shutdownOnce.Do(func() { sc.sendServeMsg(gracefulShutdownMsg) })
1473 }
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
// goAwayTimeout is the delay the serve loop passes to shutDownIn once the
// GOAWAY frame has been written; the connection is closed when that timer
// fires. It is a variable rather than a constant.
// NOTE(review): presumably so tests can shorten it; confirm.
var goAwayTimeout = 1 * time.Second
1492
// startGracefulShutdownInternal begins a graceful shutdown by starting a
// GOAWAY with ErrCodeNo. The serve loop calls it when it receives
// gracefulShutdownMsg.
func (sc *serverConn) startGracefulShutdownInternal() {
	sc.goAway(ErrCodeNo)
}
1496
1497 func (sc *serverConn) goAway(code ErrCode) {
1498 sc.serveG.check()
1499 if sc.inGoAway {
1500 if sc.goAwayCode == ErrCodeNo {
1501 sc.goAwayCode = code
1502 }
1503 return
1504 }
1505 sc.inGoAway = true
1506 sc.needToSendGoAway = true
1507 sc.goAwayCode = code
1508 sc.scheduleFrameWrite()
1509 }
1510
1511 func (sc *serverConn) shutDownIn(d time.Duration) {
1512 sc.serveG.check()
1513 sc.shutdownTimer = sc.srv.afterFunc(d, sc.onShutdownTimer)
1514 }
1515
1516 func (sc *serverConn) resetStream(se StreamError) {
1517 sc.serveG.check()
1518 sc.writeFrame(FrameWriteRequest{write: se})
1519 if st, ok := sc.streams[se.StreamID]; ok {
1520 st.resetQueued = true
1521 }
1522 }
1523
1524
1525
1526
// processFrameFromReader handles one result from the readFrames goroutine:
// either a read error or a frame to process. It reports whether the serve
// loop should keep running; false means the connection should be torn down.
func (sc *serverConn) processFrameFromReader(res readFrameResult) bool {
	sc.serveG.check()
	err := res.err
	if err != nil {
		if err == ErrFrameTooLarge {
			sc.goAway(ErrCodeFrameSize)
			return true // keep serving so the GOAWAY can be written
		}
		clientGone := err == io.EOF || err == io.ErrUnexpectedEOF || isClosedConnError(err)
		if clientGone {
			// The peer is gone; stop serving without logging.
			return false
		}
	} else {
		f := res.f
		if VerboseLogs {
			sc.vlogf("http2: server read frame %v", summarizeFrame(f))
		}
		err = sc.processFrame(f)
		if err == nil {
			return true
		}
	}

	// err is now either a read error not handled above or the error from
	// processing the frame.
	switch ev := err.(type) {
	case StreamError:
		sc.resetStream(ev)
		return true
	case goAwayFlowError:
		sc.goAway(ErrCodeFlowControl)
		return true
	case ConnectionError:
		// Raise maxClientStreamID to the offending frame's stream ID when
		// that ID is higher, before starting the GOAWAY.
		if res.f != nil {
			if id := res.f.Header().StreamID; id > sc.maxClientStreamID {
				sc.maxClientStreamID = id
			}
		}
		sc.logf("http2: server connection error from %v: %v", sc.conn.RemoteAddr(), ev)
		sc.goAway(ErrCode(ev))
		return true // keep serving so the GOAWAY can be written
	default:
		// Read errors are logged only in verbose mode; processing errors
		// are always logged.
		if res.err != nil {
			sc.vlogf("http2: server closing client connection; error reading frame from client %s: %v", sc.conn.RemoteAddr(), err)
		} else {
			sc.logf("http2: server closing client connection: %v", err)
		}
		return false
	}
}
1583
1584 func (sc *serverConn) processFrame(f Frame) error {
1585 sc.serveG.check()
1586
1587
1588 if !sc.sawFirstSettings {
1589 if _, ok := f.(*SettingsFrame); !ok {
1590 return sc.countError("first_settings", ConnectionError(ErrCodeProtocol))
1591 }
1592 sc.sawFirstSettings = true
1593 }
1594
1595
1596
1597
1598
1599 if sc.inGoAway && (sc.goAwayCode != ErrCodeNo || f.Header().StreamID > sc.maxClientStreamID) {
1600
1601 if f, ok := f.(*DataFrame); ok {
1602 if !sc.inflow.take(f.Length) {
1603 return sc.countError("data_flow", streamError(f.Header().StreamID, ErrCodeFlowControl))
1604 }
1605 sc.sendWindowUpdate(nil, int(f.Length))
1606 }
1607 return nil
1608 }
1609
1610 switch f := f.(type) {
1611 case *SettingsFrame:
1612 return sc.processSettings(f)
1613 case *MetaHeadersFrame:
1614 return sc.processHeaders(f)
1615 case *WindowUpdateFrame:
1616 return sc.processWindowUpdate(f)
1617 case *PingFrame:
1618 return sc.processPing(f)
1619 case *DataFrame:
1620 return sc.processData(f)
1621 case *RSTStreamFrame:
1622 return sc.processResetStream(f)
1623 case *PriorityFrame:
1624 return sc.processPriority(f)
1625 case *GoAwayFrame:
1626 return sc.processGoAway(f)
1627 case *PushPromiseFrame:
1628
1629
1630 return sc.countError("push_promise", ConnectionError(ErrCodeProtocol))
1631 default:
1632 sc.vlogf("http2: server ignoring frame: %v", f.Header())
1633 return nil
1634 }
1635 }
1636
1637 func (sc *serverConn) processPing(f *PingFrame) error {
1638 sc.serveG.check()
1639 if f.IsAck() {
1640 if sc.pingSent && sc.sentPingData == f.Data {
1641
1642 sc.pingSent = false
1643 sc.readIdleTimer.Reset(sc.readIdleTimeout)
1644 }
1645
1646
1647 return nil
1648 }
1649 if f.StreamID != 0 {
1650
1651
1652
1653
1654
1655 return sc.countError("ping_on_stream", ConnectionError(ErrCodeProtocol))
1656 }
1657 sc.writeFrame(FrameWriteRequest{write: writePingAck{f}})
1658 return nil
1659 }
1660
1661 func (sc *serverConn) processWindowUpdate(f *WindowUpdateFrame) error {
1662 sc.serveG.check()
1663 switch {
1664 case f.StreamID != 0:
1665 state, st := sc.state(f.StreamID)
1666 if state == stateIdle {
1667
1668
1669
1670
1671 return sc.countError("stream_idle", ConnectionError(ErrCodeProtocol))
1672 }
1673 if st == nil {
1674
1675
1676
1677
1678
1679 return nil
1680 }
1681 if !st.flow.add(int32(f.Increment)) {
1682 return sc.countError("bad_flow", streamError(f.StreamID, ErrCodeFlowControl))
1683 }
1684 default:
1685 if !sc.flow.add(int32(f.Increment)) {
1686 return goAwayFlowError{}
1687 }
1688 }
1689 sc.scheduleFrameWrite()
1690 return nil
1691 }
1692
1693 func (sc *serverConn) processResetStream(f *RSTStreamFrame) error {
1694 sc.serveG.check()
1695
1696 state, st := sc.state(f.StreamID)
1697 if state == stateIdle {
1698
1699
1700
1701
1702
1703 return sc.countError("reset_idle_stream", ConnectionError(ErrCodeProtocol))
1704 }
1705 if st != nil {
1706 st.cancelCtx()
1707 sc.closeStream(st, streamError(f.StreamID, f.ErrCode))
1708 }
1709 return nil
1710 }
1711
// closeStream moves st to stateClosed and releases everything the
// connection tracks for it: deadline timers, the open-stream counters, the
// sc.streams entry, buffered request-body flow control, the stream context
// and the write scheduler's record of the stream.
//
// err is the reason for closing. It is handed to the request body pipe as
// is; for st.closeErr a StreamError is replaced by its Cause, or by
// errStreamClosed when it has none.
//
// It panics if st is idle or already closed.
func (sc *serverConn) closeStream(st *stream, err error) {
	sc.serveG.check()
	if st.state == stateIdle || st.state == stateClosed {
		panic(fmt.Sprintf("invariant; can't close stream in state %v", st.state))
	}
	st.state = stateClosed
	if st.readDeadline != nil {
		st.readDeadline.Stop()
	}
	if st.writeDeadline != nil {
		st.writeDeadline.Stop()
	}
	if st.isPushed() {
		sc.curPushedStreams--
	} else {
		sc.curClientStreams--
	}
	delete(sc.streams, st.id)
	if len(sc.streams) == 0 {
		// That was the last stream: the connection is idle again.
		sc.setConnState(http.StateIdle)
		if sc.srv.IdleTimeout > 0 && sc.idleTimer != nil {
			sc.idleTimer.Reset(sc.srv.IdleTimeout)
		}
		if h1ServerKeepAlivesDisabled(sc.hs) {
			sc.startGracefulShutdownInternal()
		}
	}
	if p := st.body; p != nil {
		// Hand back connection-level flow control for whatever is still
		// sitting unread in the body pipe, before the pipe is closed.
		sc.sendWindowUpdate(nil, p.Len())

		p.CloseWithError(err)
	}
	// Unwrap a StreamError for st.closeErr: prefer its cause, else the
	// generic errStreamClosed.
	if e, ok := err.(StreamError); ok {
		if e.Cause != nil {
			err = e.Cause
		} else {
			err = errStreamClosed
		}
	}
	st.closeErr = err
	st.cancelCtx()
	// Closing st.cw wakes anything selecting or waiting on it (writeHeaders,
	// FlushError and CloseNotify in this file).
	st.cw.Close()
	sc.writeSched.CloseStream(st.id)
}
1758
1759 func (sc *serverConn) processSettings(f *SettingsFrame) error {
1760 sc.serveG.check()
1761 if f.IsAck() {
1762 sc.unackedSettings--
1763 if sc.unackedSettings < 0 {
1764
1765
1766
1767 return sc.countError("ack_mystery", ConnectionError(ErrCodeProtocol))
1768 }
1769 return nil
1770 }
1771 if f.NumSettings() > 100 || f.HasDuplicates() {
1772
1773
1774
1775 return sc.countError("settings_big_or_dups", ConnectionError(ErrCodeProtocol))
1776 }
1777 if err := f.ForeachSetting(sc.processSetting); err != nil {
1778 return err
1779 }
1780
1781
1782 sc.needToSendSettingsAck = true
1783 sc.scheduleFrameWrite()
1784 return nil
1785 }
1786
1787 func (sc *serverConn) processSetting(s Setting) error {
1788 sc.serveG.check()
1789 if err := s.Valid(); err != nil {
1790 return err
1791 }
1792 if VerboseLogs {
1793 sc.vlogf("http2: server processing setting %v", s)
1794 }
1795 switch s.ID {
1796 case SettingHeaderTableSize:
1797 sc.hpackEncoder.SetMaxDynamicTableSize(s.Val)
1798 case SettingEnablePush:
1799 sc.pushEnabled = s.Val != 0
1800 case SettingMaxConcurrentStreams:
1801 sc.clientMaxStreams = s.Val
1802 case SettingInitialWindowSize:
1803 return sc.processSettingInitialWindowSize(s.Val)
1804 case SettingMaxFrameSize:
1805 sc.maxFrameSize = int32(s.Val)
1806 case SettingMaxHeaderListSize:
1807 sc.peerMaxHeaderListSize = s.Val
1808 case SettingEnableConnectProtocol:
1809
1810
1811 default:
1812
1813
1814
1815 if VerboseLogs {
1816 sc.vlogf("http2: server ignoring unknown setting %v", s)
1817 }
1818 }
1819 return nil
1820 }
1821
1822 func (sc *serverConn) processSettingInitialWindowSize(val uint32) error {
1823 sc.serveG.check()
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833 old := sc.initialStreamSendWindowSize
1834 sc.initialStreamSendWindowSize = int32(val)
1835 growth := int32(val) - old
1836 for _, st := range sc.streams {
1837 if !st.flow.add(growth) {
1838
1839
1840
1841
1842
1843
1844 return sc.countError("setting_win_size", ConnectionError(ErrCodeFlowControl))
1845 }
1846 }
1847 return nil
1848 }
1849
// processData handles a DATA frame: it enforces stream state and the
// declared Content-Length, charges the frame against the receive flow
// control windows, copies the payload into the request body pipe and ends
// the stream when END_STREAM is set.
//
// f.Length is the full frame payload length used for flow control, while
// len(f.Data()) is the data handed to the body; the difference is refunded
// immediately (see the end of the function).
func (sc *serverConn) processData(f *DataFrame) error {
	sc.serveG.check()
	id := f.Header().StreamID

	data := f.Data()
	state, st := sc.state(id)
	if id == 0 || state == stateIdle {
		// DATA on stream 0 or on an idle stream is a connection error.
		return sc.countError("data_on_idle", ConnectionError(ErrCodeProtocol))
	}

	// The stream is not in a state that accepts data: it is gone, not
	// open, has already received trailers, or has a reset queued.
	if st == nil || state != stateOpen || st.gotTrailerHeader || st.resetQueued {
		// The frame still counts against the connection-level window;
		// take it and refund it straight away since it is being discarded.
		if !sc.inflow.take(f.Length) {
			return sc.countError("data_flow", streamError(id, ErrCodeFlowControl))
		}
		sc.sendWindowUpdate(nil, int(f.Length))

		if st != nil && st.resetQueued {
			// A reset for this stream is already queued; don't raise
			// another stream error.
			return nil
		}
		return sc.countError("closed", streamError(id, ErrCodeStreamClosed))
	}
	if st.body == nil {
		panic("internal error: should have a body in this state")
	}

	// The peer is sending more than the Content-Length it declared
	// (-1 means no length was declared).
	if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes {
		if !sc.inflow.take(f.Length) {
			return sc.countError("data_flow", streamError(id, ErrCodeFlowControl))
		}
		sc.sendWindowUpdate(nil, int(f.Length))

		st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes))
		return sc.countError("send_too_much", streamError(id, ErrCodeProtocol))
	}
	if f.Length > 0 {
		// Charge both the connection and the stream receive windows.
		if !takeInflows(&sc.inflow, &st.inflow, f.Length) {
			return sc.countError("flow_on_data_length", streamError(id, ErrCodeFlowControl))
		}

		if len(data) > 0 {
			st.bodyBytes += int64(len(data))
			wrote, err := st.body.Write(data)
			if err != nil {
				// The body pipe rejected the write. Refund connection-level
				// flow control for everything that was not written; the
				// stream-level window is not refunded.
				sc.sendWindowUpdate(nil, int(f.Length)-wrote)
				return nil
			}
			if wrote != len(data) {
				panic("internal error: bad Writer")
			}
		}

		// Bytes in the frame beyond the data itself never reach the
		// handler, so refund them on both windows right away.
		pad := int32(f.Length) - int32(len(data))
		sc.sendWindowUpdate32(nil, pad)
		sc.sendWindowUpdate32(st, pad)
	}
	if f.StreamEnded() {
		st.endStream()
	}
	return nil
}
1946
1947 func (sc *serverConn) processGoAway(f *GoAwayFrame) error {
1948 sc.serveG.check()
1949 if f.ErrCode != ErrCodeNo {
1950 sc.logf("http2: received GOAWAY %+v, starting graceful shutdown", f)
1951 } else {
1952 sc.vlogf("http2: received GOAWAY %+v, starting graceful shutdown", f)
1953 }
1954 sc.startGracefulShutdownInternal()
1955
1956
1957 sc.pushEnabled = false
1958 return nil
1959 }
1960
1961
1962 func (st *stream) isPushed() bool {
1963 return st.id%2 == 0
1964 }
1965
1966
1967
1968 func (st *stream) endStream() {
1969 sc := st.sc
1970 sc.serveG.check()
1971
1972 if st.declBodyBytes != -1 && st.declBodyBytes != st.bodyBytes {
1973 st.body.CloseWithError(fmt.Errorf("request declared a Content-Length of %d but only wrote %d bytes",
1974 st.declBodyBytes, st.bodyBytes))
1975 } else {
1976 st.body.closeWithErrorAndCode(io.EOF, st.copyTrailersToHandlerRequest)
1977 st.body.CloseWithError(io.EOF)
1978 }
1979 st.state = stateHalfClosedRemote
1980 }
1981
1982
1983
1984 func (st *stream) copyTrailersToHandlerRequest() {
1985 for k, vv := range st.trailer {
1986 if _, ok := st.reqTrailer[k]; ok {
1987
1988 st.reqTrailer[k] = vv
1989 }
1990 }
1991 }
1992
1993
1994
1995 func (st *stream) onReadTimeout() {
1996 if st.body != nil {
1997
1998
1999 st.body.CloseWithError(fmt.Errorf("%w", os.ErrDeadlineExceeded))
2000 }
2001 }
2002
2003
2004
2005 func (st *stream) onWriteTimeout() {
2006 st.sc.writeFrameFromHandler(FrameWriteRequest{write: StreamError{
2007 StreamID: st.id,
2008 Code: ErrCodeInternal,
2009 Cause: os.ErrDeadlineExceeded,
2010 }})
2011 }
2012
// processHeaders handles a complete header block (HEADERS plus any
// CONTINUATION frames, delivered as a MetaHeadersFrame). For an existing
// stream the block is treated as trailers; otherwise it opens a new stream,
// builds the http.Request and ResponseWriter, and schedules the handler.
func (sc *serverConn) processHeaders(f *MetaHeadersFrame) error {
	sc.serveG.check()
	id := f.StreamID

	// Streams opened by the client must have odd IDs.
	if id%2 != 1 {
		return sc.countError("headers_even", ConnectionError(ErrCodeProtocol))
	}

	// A header block on a stream that already exists is a trailer block.
	if st := sc.streams[f.StreamID]; st != nil {
		if st.resetQueued {
			// A reset for this stream is already queued; ignore the frame.
			return nil
		}

		// The peer already ended its side of the stream, so more headers
		// are a stream error.
		if st.state == stateHalfClosedRemote {
			return sc.countError("headers_half_closed", streamError(id, ErrCodeStreamClosed))
		}
		return st.processTrailerHeaders(f)
	}

	// New streams must use an ID above every ID seen so far.
	if id <= sc.maxClientStreamID {
		return sc.countError("stream_went_down", ConnectionError(ErrCodeProtocol))
	}
	sc.maxClientStreamID = id

	if sc.idleTimer != nil {
		sc.idleTimer.Stop()
	}

	// Enforce the advertised concurrent-stream limit.
	if sc.curClientStreams+1 > sc.advMaxStreams {
		if sc.unackedSettings == 0 {
			// All of this server's SETTINGS have been acknowledged, so the
			// peer knows the limit: treat it as a protocol violation.
			return sc.countError("over_max_streams", streamError(id, ErrCodeProtocol))
		}

		// Some SETTINGS are still unacknowledged, so the peer may not have
		// seen the limit yet: refuse the stream instead.
		return sc.countError("over_max_streams_race", streamError(id, ErrCodeRefusedStream))
	}

	initialState := stateOpen
	if f.StreamEnded() {
		initialState = stateHalfClosedRemote
	}
	st := sc.newStream(id, 0, initialState)

	if f.HasPriority() {
		if err := sc.checkPriority(f.StreamID, f.Priority); err != nil {
			return err
		}
		sc.writeSched.AdjustStream(st.id, f.Priority)
	}

	rw, req, err := sc.newWriterAndRequest(st, f)
	if err != nil {
		return err
	}
	st.reqTrailer = req.Trailer
	if st.reqTrailer != nil {
		st.trailer = make(http.Header)
	}
	// The pipe is nil when the request has no body (newWriterAndRequest only
	// creates it when the header block did not end the stream).
	st.body = req.Body.(*requestBody).pipe
	st.declBodyBytes = req.ContentLength

	handler := sc.handler.ServeHTTP
	if f.Truncated {
		// The header list was truncated: answer with handleHeaderListTooLong
		// (a 431 response) instead of the user's handler.
		handler = handleHeaderListTooLong
	} else if err := checkValidHTTP2RequestHeaders(req.Header); err != nil {
		handler = new400Handler(err)
	}

	// With a ReadTimeout configured, clear the connection-wide read deadline
	// and arm a per-stream timer that fires onReadTimeout instead.
	if sc.hs.ReadTimeout > 0 {
		sc.conn.SetReadDeadline(time.Time{})
		st.readDeadline = sc.srv.afterFunc(sc.hs.ReadTimeout, st.onReadTimeout)
	}

	return sc.scheduleHandler(id, rw, req, handler)
}
2123
2124 func (sc *serverConn) upgradeRequest(req *http.Request) {
2125 sc.serveG.check()
2126 id := uint32(1)
2127 sc.maxClientStreamID = id
2128 st := sc.newStream(id, 0, stateHalfClosedRemote)
2129 st.reqTrailer = req.Trailer
2130 if st.reqTrailer != nil {
2131 st.trailer = make(http.Header)
2132 }
2133 rw := sc.newResponseWriter(st, req)
2134
2135
2136
2137 if sc.hs.ReadTimeout > 0 {
2138 sc.conn.SetReadDeadline(time.Time{})
2139 }
2140
2141
2142
2143
2144 sc.curHandlers++
2145 go sc.runHandler(rw, req, sc.handler.ServeHTTP)
2146 }
2147
2148 func (st *stream) processTrailerHeaders(f *MetaHeadersFrame) error {
2149 sc := st.sc
2150 sc.serveG.check()
2151 if st.gotTrailerHeader {
2152 return sc.countError("dup_trailers", ConnectionError(ErrCodeProtocol))
2153 }
2154 st.gotTrailerHeader = true
2155 if !f.StreamEnded() {
2156 return sc.countError("trailers_not_ended", streamError(st.id, ErrCodeProtocol))
2157 }
2158
2159 if len(f.PseudoFields()) > 0 {
2160 return sc.countError("trailers_pseudo", streamError(st.id, ErrCodeProtocol))
2161 }
2162 if st.trailer != nil {
2163 for _, hf := range f.RegularFields() {
2164 key := sc.canonicalHeader(hf.Name)
2165 if !httpguts.ValidTrailerHeader(key) {
2166
2167
2168
2169 return sc.countError("trailers_bogus", streamError(st.id, ErrCodeProtocol))
2170 }
2171 st.trailer[key] = append(st.trailer[key], hf.Value)
2172 }
2173 }
2174 st.endStream()
2175 return nil
2176 }
2177
2178 func (sc *serverConn) checkPriority(streamID uint32, p PriorityParam) error {
2179 if streamID == p.StreamDep {
2180
2181
2182
2183
2184 return sc.countError("priority", streamError(streamID, ErrCodeProtocol))
2185 }
2186 return nil
2187 }
2188
2189 func (sc *serverConn) processPriority(f *PriorityFrame) error {
2190 if err := sc.checkPriority(f.StreamID, f.PriorityParam); err != nil {
2191 return err
2192 }
2193 sc.writeSched.AdjustStream(f.StreamID, f.PriorityParam)
2194 return nil
2195 }
2196
2197 func (sc *serverConn) newStream(id, pusherID uint32, state streamState) *stream {
2198 sc.serveG.check()
2199 if id == 0 {
2200 panic("internal error: cannot create stream with id 0")
2201 }
2202
2203 ctx, cancelCtx := context.WithCancel(sc.baseCtx)
2204 st := &stream{
2205 sc: sc,
2206 id: id,
2207 state: state,
2208 ctx: ctx,
2209 cancelCtx: cancelCtx,
2210 }
2211 st.cw.Init()
2212 st.flow.conn = &sc.flow
2213 st.flow.add(sc.initialStreamSendWindowSize)
2214 st.inflow.init(sc.initialStreamRecvWindowSize)
2215 if sc.hs.WriteTimeout > 0 {
2216 st.writeDeadline = sc.srv.afterFunc(sc.hs.WriteTimeout, st.onWriteTimeout)
2217 }
2218
2219 sc.streams[id] = st
2220 sc.writeSched.OpenStream(st.id, OpenStreamOptions{PusherID: pusherID})
2221 if st.isPushed() {
2222 sc.curPushedStreams++
2223 } else {
2224 sc.curClientStreams++
2225 }
2226 if sc.curOpenStreams() == 1 {
2227 sc.setConnState(http.StateActive)
2228 }
2229
2230 return st
2231 }
2232
2233 func (sc *serverConn) newWriterAndRequest(st *stream, f *MetaHeadersFrame) (*responseWriter, *http.Request, error) {
2234 sc.serveG.check()
2235
2236 rp := requestParam{
2237 method: f.PseudoValue("method"),
2238 scheme: f.PseudoValue("scheme"),
2239 authority: f.PseudoValue("authority"),
2240 path: f.PseudoValue("path"),
2241 protocol: f.PseudoValue("protocol"),
2242 }
2243
2244
2245 if disableExtendedConnectProtocol && rp.protocol != "" {
2246 return nil, nil, sc.countError("bad_connect", streamError(f.StreamID, ErrCodeProtocol))
2247 }
2248
2249 isConnect := rp.method == "CONNECT"
2250 if isConnect {
2251 if rp.protocol == "" && (rp.path != "" || rp.scheme != "" || rp.authority == "") {
2252 return nil, nil, sc.countError("bad_connect", streamError(f.StreamID, ErrCodeProtocol))
2253 }
2254 } else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") {
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265 return nil, nil, sc.countError("bad_path_method", streamError(f.StreamID, ErrCodeProtocol))
2266 }
2267
2268 rp.header = make(http.Header)
2269 for _, hf := range f.RegularFields() {
2270 rp.header.Add(sc.canonicalHeader(hf.Name), hf.Value)
2271 }
2272 if rp.authority == "" {
2273 rp.authority = rp.header.Get("Host")
2274 }
2275 if rp.protocol != "" {
2276 rp.header.Set(":protocol", rp.protocol)
2277 }
2278
2279 rw, req, err := sc.newWriterAndRequestNoBody(st, rp)
2280 if err != nil {
2281 return nil, nil, err
2282 }
2283 bodyOpen := !f.StreamEnded()
2284 if bodyOpen {
2285 if vv, ok := rp.header["Content-Length"]; ok {
2286 if cl, err := strconv.ParseUint(vv[0], 10, 63); err == nil {
2287 req.ContentLength = int64(cl)
2288 } else {
2289 req.ContentLength = 0
2290 }
2291 } else {
2292 req.ContentLength = -1
2293 }
2294 req.Body.(*requestBody).pipe = &pipe{
2295 b: &dataBuffer{expected: req.ContentLength},
2296 }
2297 }
2298 return rw, req, nil
2299 }
2300
// requestParam carries the pieces of a request that newWriterAndRequest
// extracts from a header block and passes to newWriterAndRequestNoBody.
type requestParam struct {
	// method is the value of the :method pseudo-header.
	method string
	// scheme, authority and path are the values of the corresponding
	// pseudo-headers; authority falls back to the Host header when empty.
	scheme, authority, path string
	// protocol is the value of the :protocol pseudo-header; it is empty
	// when the pseudo-header is absent.
	protocol string
	// header holds the regular (non-pseudo) header fields.
	header http.Header
}
2307
// newWriterAndRequestNoBody builds the http.Request and ResponseWriter for
// st from already-validated request parameters. The request's Body is a
// *requestBody without a pipe; the caller attaches one if a body follows.
//
// rp.header is modified in place: an "Expect: 100-continue" header is
// removed, multiple Cookie headers are merged, and the Trailer header is
// consumed. A path that fails url.ParseRequestURI yields a stream-level
// PROTOCOL_ERROR.
func (sc *serverConn) newWriterAndRequestNoBody(st *stream, rp requestParam) (*responseWriter, *http.Request, error) {
	sc.serveG.check()

	// TLS state is only exposed on the request for the https scheme.
	var tlsState *tls.ConnectionState
	if rp.scheme == "https" {
		tlsState = sc.tlsState
	}

	// Remember that the client expects a 100-continue and drop the header;
	// requestBody.Read sends the interim response on first read.
	needsContinue := httpguts.HeaderValuesContainsToken(rp.header["Expect"], "100-continue")
	if needsContinue {
		rp.header.Del("Expect")
	}

	// Merge multiple Cookie header fields into a single "; "-separated value.
	if cookies := rp.header["Cookie"]; len(cookies) > 1 {
		rp.header.Set("Cookie", strings.Join(cookies, "; "))
	}

	// Collect the trailer keys announced in the Trailer header. The map is
	// created lazily, so it stays nil when no usable key is announced.
	var trailer http.Header
	for _, v := range rp.header["Trailer"] {
		for _, key := range strings.Split(v, ",") {
			key = http.CanonicalHeaderKey(textproto.TrimString(key))
			switch key {
			case "Transfer-Encoding", "Trailer", "Content-Length":
				// These keys are not accepted as trailers; skip them.
			default:
				if trailer == nil {
					trailer = make(http.Header)
				}
				trailer[key] = nil
			}
		}
	}
	delete(rp.header, "Trailer")

	var url_ *url.URL
	var requestURI string
	if rp.method == "CONNECT" && rp.protocol == "" {
		// A plain CONNECT has no path: the authority is used as both the
		// URL host and the RequestURI.
		url_ = &url.URL{Host: rp.authority}
		requestURI = rp.authority
	} else {
		var err error
		url_, err = url.ParseRequestURI(rp.path)
		if err != nil {
			return nil, nil, sc.countError("bad_path", streamError(st.id, ErrCodeProtocol))
		}
		requestURI = rp.path
	}

	body := &requestBody{
		conn:          sc,
		stream:        st,
		needsContinue: needsContinue,
	}
	req := &http.Request{
		Method:     rp.method,
		URL:        url_,
		RemoteAddr: sc.remoteAddrStr,
		Header:     rp.header,
		RequestURI: requestURI,
		Proto:      "HTTP/2.0",
		ProtoMajor: 2,
		ProtoMinor: 0,
		TLS:        tlsState,
		Host:       rp.authority,
		Body:       body,
		Trailer:    trailer,
	}
	// Tie the request to the stream's context so it is cancelled with the
	// stream.
	req = req.WithContext(st.ctx)

	rw := sc.newResponseWriter(st, req)
	return rw, req, nil
}
2382
2383 func (sc *serverConn) newResponseWriter(st *stream, req *http.Request) *responseWriter {
2384 rws := responseWriterStatePool.Get().(*responseWriterState)
2385 bwSave := rws.bw
2386 *rws = responseWriterState{}
2387 rws.conn = sc
2388 rws.bw = bwSave
2389 rws.bw.Reset(chunkWriter{rws})
2390 rws.stream = st
2391 rws.req = req
2392 return &responseWriter{rws: rws}
2393 }
2394
// unstartedHandler records a request whose handler goroutine has not been
// started yet. scheduleHandler queues one when the handler limit has been
// reached; handlerDone starts queued entries as running handlers finish.
type unstartedHandler struct {
	// streamID identifies the stream; handlerDone uses it to skip entries
	// whose stream is no longer in sc.streams.
	streamID uint32
	// rw, req and handler are the arguments for the eventual runHandler call.
	rw      *responseWriter
	req     *http.Request
	handler func(http.ResponseWriter, *http.Request)
}
2401
2402
2403
2404 func (sc *serverConn) scheduleHandler(streamID uint32, rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) error {
2405 sc.serveG.check()
2406 maxHandlers := sc.advMaxStreams
2407 if sc.curHandlers < maxHandlers {
2408 sc.curHandlers++
2409 go sc.runHandler(rw, req, handler)
2410 return nil
2411 }
2412 if len(sc.unstartedHandlers) > int(4*sc.advMaxStreams) {
2413 return sc.countError("too_many_early_resets", ConnectionError(ErrCodeEnhanceYourCalm))
2414 }
2415 sc.unstartedHandlers = append(sc.unstartedHandlers, unstartedHandler{
2416 streamID: streamID,
2417 rw: rw,
2418 req: req,
2419 handler: handler,
2420 })
2421 return nil
2422 }
2423
2424 func (sc *serverConn) handlerDone() {
2425 sc.serveG.check()
2426 sc.curHandlers--
2427 i := 0
2428 maxHandlers := sc.advMaxStreams
2429 for ; i < len(sc.unstartedHandlers); i++ {
2430 u := sc.unstartedHandlers[i]
2431 if sc.streams[u.streamID] == nil {
2432
2433 continue
2434 }
2435 if sc.curHandlers >= maxHandlers {
2436 break
2437 }
2438 sc.curHandlers++
2439 go sc.runHandler(u.rw, u.req, u.handler)
2440 sc.unstartedHandlers[i] = unstartedHandler{}
2441 }
2442 sc.unstartedHandlers = sc.unstartedHandlers[i:]
2443 if len(sc.unstartedHandlers) == 0 {
2444 sc.unstartedHandlers = nil
2445 }
2446 }
2447
2448
// runHandler runs handler for one request in its own goroutine. When the
// handler returns or panics, the stream context is cancelled, multipart
// temporary files are removed, and handlerDoneMsg is posted to the serve
// loop. A handler that does not return normally gets its stream reset with
// handlerPanicRST; the panic value is logged with a stack trace unless it
// is nil or http.ErrAbortHandler.
func (sc *serverConn) runHandler(rw *responseWriter, req *http.Request, handler func(http.ResponseWriter, *http.Request)) {
	sc.srv.markNewGoroutine()
	// Deferred first, so it runs last: the serve loop is notified only
	// after the cleanup below has finished.
	defer sc.sendServeMsg(handlerDoneMsg)
	// didPanic stays true unless handler returns normally.
	didPanic := true
	defer func() {
		rw.rws.stream.cancelCtx()
		if req.MultipartForm != nil {
			req.MultipartForm.RemoveAll()
		}
		if didPanic {
			// recover must be called directly in this deferred function.
			e := recover()
			sc.writeFrameFromHandler(FrameWriteRequest{
				write:  handlerPanicRST{rw.rws.stream.id},
				stream: rw.rws.stream,
			})

			if e != nil && e != http.ErrAbortHandler {
				const size = 64 << 10
				buf := make([]byte, size)
				buf = buf[:runtime.Stack(buf, false)]
				sc.logf("http2: panic serving %v: %v\n%s", sc.conn.RemoteAddr(), e, buf)
			}
			return
		}
		rw.handlerDone()
	}()
	handler(rw, req)
	didPanic = false
}
2478
// handleHeaderListTooLong is the handler used in place of the user's
// handler when a request's header list was truncated. It replies with
// status 431 (Request Header Fields Too Large) and a short HTML body; the
// request itself is not inspected.
func handleHeaderListTooLong(w http.ResponseWriter, r *http.Request) {
	const body = "<h1>HTTP Error 431</h1><p>Request Header Field(s) Too Large</p>"
	w.WriteHeader(http.StatusRequestHeaderFieldsTooLarge)
	io.WriteString(w, body)
}
2488
2489
2490
2491 func (sc *serverConn) writeHeaders(st *stream, headerData *writeResHeaders) error {
2492 sc.serveG.checkNotOn()
2493 var errc chan error
2494 if headerData.h != nil {
2495
2496
2497
2498
2499 errc = errChanPool.Get().(chan error)
2500 }
2501 if err := sc.writeFrameFromHandler(FrameWriteRequest{
2502 write: headerData,
2503 stream: st,
2504 done: errc,
2505 }); err != nil {
2506 return err
2507 }
2508 if errc != nil {
2509 select {
2510 case err := <-errc:
2511 errChanPool.Put(errc)
2512 return err
2513 case <-sc.doneServing:
2514 return errClientDisconnected
2515 case <-st.cw:
2516 return errStreamClosed
2517 }
2518 }
2519 return nil
2520 }
2521
2522
2523 func (sc *serverConn) write100ContinueHeaders(st *stream) {
2524 sc.writeFrameFromHandler(FrameWriteRequest{
2525 write: write100ContinueHeadersFrame{st.id},
2526 stream: st,
2527 })
2528 }
2529
2530
2531
// bodyReadMsg is the message noteBodyReadFromHandler sends on
// sc.bodyReadCh to report that a handler read request-body bytes.
type bodyReadMsg struct {
	st *stream // stream whose body was read
	n  int     // number of bytes read; always > 0 when sent by noteBodyReadFromHandler
}
2536
2537
2538
2539
2540 func (sc *serverConn) noteBodyReadFromHandler(st *stream, n int, err error) {
2541 sc.serveG.checkNotOn()
2542 if n > 0 {
2543 select {
2544 case sc.bodyReadCh <- bodyReadMsg{st, n}:
2545 case <-sc.doneServing:
2546 }
2547 }
2548 }
2549
2550 func (sc *serverConn) noteBodyRead(st *stream, n int) {
2551 sc.serveG.check()
2552 sc.sendWindowUpdate(nil, n)
2553 if st.state != stateHalfClosedRemote && st.state != stateClosed {
2554
2555
2556 sc.sendWindowUpdate(st, n)
2557 }
2558 }
2559
2560
// sendWindowUpdate32 is sendWindowUpdate for an int32 amount: it converts n
// to int and delegates. As with sendWindowUpdate, a nil st selects the
// connection-level window.
func (sc *serverConn) sendWindowUpdate32(st *stream, n int32) {
	sc.sendWindowUpdate(st, int(n))
}
2564
2565
2566 func (sc *serverConn) sendWindowUpdate(st *stream, n int) {
2567 sc.serveG.check()
2568 var streamID uint32
2569 var send int32
2570 if st == nil {
2571 send = sc.inflow.add(n)
2572 } else {
2573 streamID = st.id
2574 send = st.inflow.add(n)
2575 }
2576 if send == 0 {
2577 return
2578 }
2579 sc.writeFrame(FrameWriteRequest{
2580 write: writeWindowUpdate{streamID: streamID, n: uint32(send)},
2581 stream: st,
2582 })
2583 }
2584
2585
2586
// requestBody is the io.ReadCloser installed as http.Request.Body for
// requests served by this package (see newWriterAndRequestNoBody).
type requestBody struct {
	_         incomparable // presumably prevents == comparison of values — confirm at the type's definition
	stream    *stream      // stream the body belongs to
	conn      *serverConn  // connection used to report reads and send 100-continue
	closeOnce sync.Once    // makes Close act only once
	sawEOF    bool         // set once pipe.Read has returned io.EOF
	pipe      *pipe        // body data; nil when the request has no body
	// needsContinue is true until the first Read, which then sends the
	// 100-continue interim response.
	needsContinue bool
}
2596
2597 func (b *requestBody) Close() error {
2598 b.closeOnce.Do(func() {
2599 if b.pipe != nil {
2600 b.pipe.BreakWithError(errClosedBody)
2601 }
2602 })
2603 return nil
2604 }
2605
2606 func (b *requestBody) Read(p []byte) (n int, err error) {
2607 if b.needsContinue {
2608 b.needsContinue = false
2609 b.conn.write100ContinueHeaders(b.stream)
2610 }
2611 if b.pipe == nil || b.sawEOF {
2612 return 0, io.EOF
2613 }
2614 n, err = b.pipe.Read(p)
2615 if err == io.EOF {
2616 b.sawEOF = true
2617 }
2618 if b.conn == nil && inTests {
2619 return
2620 }
2621 b.conn.noteBodyReadFromHandler(b.stream, n, err)
2622 return
2623 }
2624
2625
2626
2627
2628
2629
2630
// responseWriter is the http.ResponseWriter handed to handlers. It is a
// thin handle around a *responseWriterState; several methods treat a nil
// rws as "used after the handler finished" and panic.
type responseWriter struct {
	rws *responseWriterState
}
2634
2635
// Compile-time checks that *responseWriter implements these interfaces.
var (
	_ http.CloseNotifier = (*responseWriter)(nil)
	_ http.Flusher       = (*responseWriter)(nil)
	_ stringWriter       = (*responseWriter)(nil)
)
2641
// responseWriterState holds the per-response state behind a
// responseWriter. Values are recycled through responseWriterStatePool (see
// newResponseWriter).
type responseWriterState struct {
	// Set by newResponseWriter for the lifetime of the response.
	stream *stream
	req    *http.Request
	conn   *serverConn

	// bw buffers handler writes and flushes them through chunkWriter{rws},
	// i.e. into writeChunk. It is the one field kept across pool reuse.
	bw *bufio.Writer

	handlerHeader http.Header // map returned by Header(); created on first use
	snapHeader    http.Header // headers writeChunk sends; populated outside this part of the file
	trailers      []string    // declared trailer keys (see declareTrailer)
	status        int         // response status code sent by writeChunk
	wroteHeader   bool        // when false, writeChunk first calls writeHeader(200)
	sentHeader    bool        // set by writeChunk once it has started sending the header block
	handlerDone   bool        // read by writeChunk to decide on END_STREAM and trailers; set elsewhere

	sentContentLen int64 // Content-Length parsed from the response headers by writeChunk; zero if none
	wroteBytes     int64 // not referenced in this part of the file; presumably bytes written so far — confirm

	closeNotifierMu sync.Mutex // guards closeNotifierCh
	closeNotifierCh chan bool  // channel returned by CloseNotify; nil until first requested
}
2666
// chunkWriter is the io.Writer behind responseWriterState.bw: each Write is
// forwarded to rws.writeChunk.
type chunkWriter struct{ rws *responseWriterState }
2668
2669 func (cw chunkWriter) Write(p []byte) (n int, err error) {
2670 n, err = cw.rws.writeChunk(p)
2671 if err == errStreamClosed {
2672
2673
2674 err = cw.rws.stream.closeErr
2675 }
2676 return n, err
2677 }
2678
2679 func (rws *responseWriterState) hasTrailers() bool { return len(rws.trailers) > 0 }
2680
2681 func (rws *responseWriterState) hasNonemptyTrailers() bool {
2682 for _, trailer := range rws.trailers {
2683 if _, ok := rws.handlerHeader[trailer]; ok {
2684 return true
2685 }
2686 }
2687 return false
2688 }
2689
2690
2691
2692
2693 func (rws *responseWriterState) declareTrailer(k string) {
2694 k = http.CanonicalHeaderKey(k)
2695 if !httpguts.ValidTrailerHeader(k) {
2696
2697 rws.conn.logf("ignoring invalid trailer %q", k)
2698 return
2699 }
2700 if !strSliceContains(rws.trailers, k) {
2701 rws.trailers = append(rws.trailers, k)
2702 }
2703 }
2704
2705
2706
2707
2708
2709
2710
// writeChunk sends p as response data for the stream. On its first call it
// also sends the response header block, and once the handler is done it
// ends the stream, sending trailers when any declared trailer has a value.
//
// It returns the number of bytes of p accounted for and any error from
// sending headers, data or trailers. For HEAD requests no data is sent but
// len(p) is still reported as written.
func (rws *responseWriterState) writeChunk(p []byte) (n int, err error) {
	// Writing without an explicit WriteHeader implies status 200.
	if !rws.wroteHeader {
		rws.writeHeader(200)
	}

	// Once the handler has returned, pick up trailers it set using the
	// TrailerPrefix convention.
	if rws.handlerDone {
		rws.promoteUndeclaredTrailers()
	}

	isHeadResp := rws.req.Method == "HEAD"
	if !rws.sentHeader {
		rws.sentHeader = true
		var ctype, clen string
		// Content-Length is taken out of the header map and passed
		// separately below; an unparseable value is dropped.
		if clen = rws.snapHeader.Get("Content-Length"); clen != "" {
			rws.snapHeader.Del("Content-Length")
			if cl, err := strconv.ParseUint(clen, 10, 63); err == nil {
				rws.sentContentLen = int64(cl)
			} else {
				clen = ""
			}
		}
		_, hasContentLength := rws.snapHeader["Content-Length"]
		// If the handler has already finished, p is the entire body, so its
		// length can be declared (for a HEAD response only when p is
		// non-empty).
		if !hasContentLength && clen == "" && rws.handlerDone && bodyAllowedForStatus(rws.status) && (len(p) > 0 || !isHeadResp) {
			clen = strconv.Itoa(len(p))
		}
		_, hasContentType := rws.snapHeader["Content-Type"]

		// Sniff a Content-Type from the first chunk only when the handler
		// set neither Content-Type nor a non-empty Content-Encoding.
		ce := rws.snapHeader.Get("Content-Encoding")
		hasCE := len(ce) > 0
		if !hasCE && !hasContentType && bodyAllowedForStatus(rws.status) && len(p) > 0 {
			ctype = http.DetectContentType(p)
		}
		var date string
		if _, ok := rws.snapHeader["Date"]; !ok {
			// No Date header was set by the handler; supply the current time.
			date = rws.conn.srv.now().UTC().Format(http.TimeFormat)
		}

		// Register trailer keys announced through the Trailer header.
		for _, v := range rws.snapHeader["Trailer"] {
			foreachHeaderElement(v, rws.declareTrailer)
		}

		// The Connection header is never forwarded; "Connection: close"
		// instead starts a graceful shutdown of the whole connection.
		if _, ok := rws.snapHeader["Connection"]; ok {
			v := rws.snapHeader.Get("Connection")
			delete(rws.snapHeader, "Connection")
			if v == "close" {
				rws.conn.startGracefulShutdown()
			}
		}

		// The header block ends the stream when nothing else will follow:
		// the handler is done with no trailers and no data, or this is a
		// HEAD response.
		endStream := (rws.handlerDone && !rws.hasTrailers() && len(p) == 0) || isHeadResp
		err = rws.conn.writeHeaders(rws.stream, &writeResHeaders{
			streamID:      rws.stream.id,
			httpResCode:   rws.status,
			h:             rws.snapHeader,
			endStream:     endStream,
			contentType:   ctype,
			contentLength: clen,
			date:          date,
		})
		if err != nil {
			return 0, err
		}
		if endStream {
			return 0, nil
		}
	}
	if isHeadResp {
		// HEAD responses carry no body; report the bytes as written.
		return len(p), nil
	}
	if len(p) == 0 && !rws.handlerDone {
		return 0, nil
	}

	// Trailers are sent only if the handler actually set a value for at
	// least one declared key; otherwise the data frame ends the stream.
	hasNonemptyTrailers := rws.hasNonemptyTrailers()
	endStream := rws.handlerDone && !hasNonemptyTrailers
	if len(p) > 0 || endStream {
		// An empty data write happens only when it ends the stream.
		if err := rws.conn.writeDataFromHandler(rws.stream, p, endStream); err != nil {
			return 0, err
		}
	}

	if rws.handlerDone && hasNonemptyTrailers {
		err = rws.conn.writeHeaders(rws.stream, &writeResHeaders{
			streamID:  rws.stream.id,
			h:         rws.handlerHeader,
			trailers:  rws.trailers,
			endStream: true,
		})
		return len(p), err
	}
	return len(p), nil
}
2813
2814
2815
2816
2817
2818
2819
2820
2821
2822
2823
2824
2825
2826
// TrailerPrefix is a prefix for keys in a ResponseWriter's header map that
// marks the entry as a response trailer rather than a header. Once the
// handler is done, promoteUndeclaredTrailers strips the prefix from such
// keys, declares the remainder as a trailer and stores the values under the
// canonical form of the stripped key.
const TrailerPrefix = "Trailer:"
2828
2829
2830
2831
2832
2833
2834
2835
2836
2837
2838
2839
2840
2841
2842
2843
2844
2845
2846
2847
2848
2849
2850 func (rws *responseWriterState) promoteUndeclaredTrailers() {
2851 for k, vv := range rws.handlerHeader {
2852 if !strings.HasPrefix(k, TrailerPrefix) {
2853 continue
2854 }
2855 trailerKey := strings.TrimPrefix(k, TrailerPrefix)
2856 rws.declareTrailer(trailerKey)
2857 rws.handlerHeader[http.CanonicalHeaderKey(trailerKey)] = vv
2858 }
2859
2860 if len(rws.trailers) > 1 {
2861 sorter := sorterPool.Get().(*sorter)
2862 sorter.SortStrings(rws.trailers)
2863 sorterPool.Put(sorter)
2864 }
2865 }
2866
2867 func (w *responseWriter) SetReadDeadline(deadline time.Time) error {
2868 st := w.rws.stream
2869 if !deadline.IsZero() && deadline.Before(w.rws.conn.srv.now()) {
2870
2871
2872 st.onReadTimeout()
2873 return nil
2874 }
2875 w.rws.conn.sendServeMsg(func(sc *serverConn) {
2876 if st.readDeadline != nil {
2877 if !st.readDeadline.Stop() {
2878
2879 return
2880 }
2881 }
2882 if deadline.IsZero() {
2883 st.readDeadline = nil
2884 } else if st.readDeadline == nil {
2885 st.readDeadline = sc.srv.afterFunc(deadline.Sub(sc.srv.now()), st.onReadTimeout)
2886 } else {
2887 st.readDeadline.Reset(deadline.Sub(sc.srv.now()))
2888 }
2889 })
2890 return nil
2891 }
2892
2893 func (w *responseWriter) SetWriteDeadline(deadline time.Time) error {
2894 st := w.rws.stream
2895 if !deadline.IsZero() && deadline.Before(w.rws.conn.srv.now()) {
2896
2897
2898 st.onWriteTimeout()
2899 return nil
2900 }
2901 w.rws.conn.sendServeMsg(func(sc *serverConn) {
2902 if st.writeDeadline != nil {
2903 if !st.writeDeadline.Stop() {
2904
2905 return
2906 }
2907 }
2908 if deadline.IsZero() {
2909 st.writeDeadline = nil
2910 } else if st.writeDeadline == nil {
2911 st.writeDeadline = sc.srv.afterFunc(deadline.Sub(sc.srv.now()), st.onWriteTimeout)
2912 } else {
2913 st.writeDeadline.Reset(deadline.Sub(sc.srv.now()))
2914 }
2915 })
2916 return nil
2917 }
2918
// EnableFullDuplex performs no work and always returns nil.
// NOTE(review): presumably a no-op because this response writer needs no
// opt-in to read the request body while writing the response — confirm.
func (w *responseWriter) EnableFullDuplex() error {

	return nil
}
2923
2924 func (w *responseWriter) Flush() {
2925 w.FlushError()
2926 }
2927
2928 func (w *responseWriter) FlushError() error {
2929 rws := w.rws
2930 if rws == nil {
2931 panic("Header called after Handler finished")
2932 }
2933 var err error
2934 if rws.bw.Buffered() > 0 {
2935 err = rws.bw.Flush()
2936 } else {
2937
2938
2939
2940
2941 _, err = chunkWriter{rws}.Write(nil)
2942 if err == nil {
2943 select {
2944 case <-rws.stream.cw:
2945 err = rws.stream.closeErr
2946 default:
2947 }
2948 }
2949 }
2950 return err
2951 }
2952
2953 func (w *responseWriter) CloseNotify() <-chan bool {
2954 rws := w.rws
2955 if rws == nil {
2956 panic("CloseNotify called after Handler finished")
2957 }
2958 rws.closeNotifierMu.Lock()
2959 ch := rws.closeNotifierCh
2960 if ch == nil {
2961 ch = make(chan bool, 1)
2962 rws.closeNotifierCh = ch
2963 cw := rws.stream.cw
2964 go func() {
2965 cw.Wait()
2966 ch <- true
2967 }()
2968 }
2969 rws.closeNotifierMu.Unlock()
2970 return ch
2971 }
2972
2973 func (w *responseWriter) Header() http.Header {
2974 rws := w.rws
2975 if rws == nil {
2976 panic("Header called after Handler finished")
2977 }
2978 if rws.handlerHeader == nil {
2979 rws.handlerHeader = make(http.Header)
2980 }
2981 return rws.handlerHeader
2982 }
2983
2984
// checkWriteHeaderCode panics unless code lies in the inclusive range
// 100..999, i.e. unless it is a three-digit number. It does not check that
// the code is a registered HTTP status.
func checkWriteHeaderCode(code int) {
	if code >= 100 && code <= 999 {
		return
	}
	panic(fmt.Sprintf("invalid WriteHeader code %v", code))
}
3000
3001 func (w *responseWriter) WriteHeader(code int) {
3002 rws := w.rws
3003 if rws == nil {
3004 panic("WriteHeader called after Handler finished")
3005 }
3006 rws.writeHeader(code)
3007 }
3008
3009 func (rws *responseWriterState) writeHeader(code int) {
3010 if rws.wroteHeader {
3011 return
3012 }
3013
3014 checkWriteHeaderCode(code)
3015
3016
3017 if code >= 100 && code <= 199 {
3018
3019 h := rws.handlerHeader
3020
3021 _, cl := h["Content-Length"]
3022 _, te := h["Transfer-Encoding"]
3023 if cl || te {
3024 h = h.Clone()
3025 h.Del("Content-Length")
3026 h.Del("Transfer-Encoding")
3027 }
3028
3029 rws.conn.writeHeaders(rws.stream, &writeResHeaders{
3030 streamID: rws.stream.id,
3031 httpResCode: code,
3032 h: h,
3033 endStream: rws.handlerDone && !rws.hasTrailers(),
3034 })
3035
3036 return
3037 }
3038
3039 rws.wroteHeader = true
3040 rws.status = code
3041 if len(rws.handlerHeader) > 0 {
3042 rws.snapHeader = cloneHeader(rws.handlerHeader)
3043 }
3044 }
3045
// cloneHeader returns a deep copy of h: a new map whose value slices are
// freshly allocated, so changes to the copy never affect h. The result is
// always a non-nil map, and every value slice in it is non-nil.
func cloneHeader(h http.Header) http.Header {
	dup := make(http.Header, len(h))
	for key, vals := range h {
		dup[key] = append(make([]string, 0, len(vals)), vals...)
	}
	return dup
}
3055
3056
3057
3058
3059
3060
3061
3062
3063
// Write writes p as response body data by delegating to w.write with the
// byte-slice argument populated and an empty string argument. See write for
// the returned values and the conditions under which it panics.
func (w *responseWriter) Write(p []byte) (n int, err error) {
	return w.write(len(p), p, "")
}
3067
// WriteString writes s as response body data by delegating to w.write with
// a nil byte slice and the string argument populated. See write for the
// returned values and the conditions under which it panics.
func (w *responseWriter) WriteString(s string) (n int, err error) {
	return w.write(len(s), nil, s)
}
3071
3072
3073 func (w *responseWriter) write(lenData int, dataB []byte, dataS string) (n int, err error) {
3074 rws := w.rws
3075 if rws == nil {
3076 panic("Write called after Handler finished")
3077 }
3078 if !rws.wroteHeader {
3079 w.WriteHeader(200)
3080 }
3081 if !bodyAllowedForStatus(rws.status) {
3082 return 0, http.ErrBodyNotAllowed
3083 }
3084 rws.wroteBytes += int64(len(dataB)) + int64(len(dataS))
3085 if rws.sentContentLen != 0 && rws.wroteBytes > rws.sentContentLen {
3086
3087 return 0, errors.New("http2: handler wrote more than declared Content-Length")
3088 }
3089
3090 if dataB != nil {
3091 return rws.bw.Write(dataB)
3092 } else {
3093 return rws.bw.WriteString(dataS)
3094 }
3095 }
3096
// handlerDone finalizes the response writer: it marks the state as
// handler-done, flushes, detaches the state from w, and returns the state to
// responseWriterStatePool.
func (w *responseWriter) handlerDone() {
	rws := w.rws
	// Recorded before the flush below runs.
	rws.handlerDone = true
	w.Flush()
	// With w.rws nil, later calls to methods such as Header, Write, or
	// CloseNotify on w panic instead of touching recycled state.
	w.rws = nil
	responseWriterStatePool.Put(rws)
}
3104
3105
// Errors returned by Push.
var (
	// ErrRecursivePush is returned when Push is called on a stream that
	// was itself pushed.
	ErrRecursivePush = errors.New("http2: recursive push not allowed")
	// ErrPushLimitReached is returned when allocating a promised stream
	// would exceed the client's stream limit, or when promised stream IDs
	// are exhausted.
	ErrPushLimitReached = errors.New("http2: push would exceed peer's SETTINGS_MAX_CONCURRENT_STREAMS")
)

// Compile-time check that *responseWriter implements http.Pusher.
var _ http.Pusher = (*responseWriter)(nil)
3112
3113 func (w *responseWriter) Push(target string, opts *http.PushOptions) error {
3114 st := w.rws.stream
3115 sc := st.sc
3116 sc.serveG.checkNotOn()
3117
3118
3119
3120 if st.isPushed() {
3121 return ErrRecursivePush
3122 }
3123
3124 if opts == nil {
3125 opts = new(http.PushOptions)
3126 }
3127
3128
3129 if opts.Method == "" {
3130 opts.Method = "GET"
3131 }
3132 if opts.Header == nil {
3133 opts.Header = http.Header{}
3134 }
3135 wantScheme := "http"
3136 if w.rws.req.TLS != nil {
3137 wantScheme = "https"
3138 }
3139
3140
3141 u, err := url.Parse(target)
3142 if err != nil {
3143 return err
3144 }
3145 if u.Scheme == "" {
3146 if !strings.HasPrefix(target, "/") {
3147 return fmt.Errorf("target must be an absolute URL or an absolute path: %q", target)
3148 }
3149 u.Scheme = wantScheme
3150 u.Host = w.rws.req.Host
3151 } else {
3152 if u.Scheme != wantScheme {
3153 return fmt.Errorf("cannot push URL with scheme %q from request with scheme %q", u.Scheme, wantScheme)
3154 }
3155 if u.Host == "" {
3156 return errors.New("URL must have a host")
3157 }
3158 }
3159 for k := range opts.Header {
3160 if strings.HasPrefix(k, ":") {
3161 return fmt.Errorf("promised request headers cannot include pseudo header %q", k)
3162 }
3163
3164
3165
3166
3167 if asciiEqualFold(k, "content-length") ||
3168 asciiEqualFold(k, "content-encoding") ||
3169 asciiEqualFold(k, "trailer") ||
3170 asciiEqualFold(k, "te") ||
3171 asciiEqualFold(k, "expect") ||
3172 asciiEqualFold(k, "host") {
3173 return fmt.Errorf("promised request headers cannot include %q", k)
3174 }
3175 }
3176 if err := checkValidHTTP2RequestHeaders(opts.Header); err != nil {
3177 return err
3178 }
3179
3180
3181
3182
3183 if opts.Method != "GET" && opts.Method != "HEAD" {
3184 return fmt.Errorf("method %q must be GET or HEAD", opts.Method)
3185 }
3186
3187 msg := &startPushRequest{
3188 parent: st,
3189 method: opts.Method,
3190 url: u,
3191 header: cloneHeader(opts.Header),
3192 done: errChanPool.Get().(chan error),
3193 }
3194
3195 select {
3196 case <-sc.doneServing:
3197 return errClientDisconnected
3198 case <-st.cw:
3199 return errStreamClosed
3200 case sc.serveMsgCh <- msg:
3201 }
3202
3203 select {
3204 case <-sc.doneServing:
3205 return errClientDisconnected
3206 case <-st.cw:
3207 return errStreamClosed
3208 case err := <-msg.done:
3209 errChanPool.Put(msg.done)
3210 return err
3211 }
3212 }
3213
// startPushRequest is the message Push sends over the connection's
// serveMsgCh to request a server push; startPush consumes it.
type startPushRequest struct {
	parent *stream     // stream on which Push was called
	method string      // method of the promised request ("GET" or "HEAD" after Push's validation)
	url    *url.URL    // parsed push target, with scheme and host filled in by Push
	header http.Header // copy of the caller's push headers
	done   chan error  // receives the outcome of the push
}
3221
// startPush processes a push request submitted by Push. It rejects the
// request if the parent stream is not open or half-closed (remote), or if
// push is disabled on the connection; otherwise it queues a writePushPromise
// frame on the parent stream. Early rejections are delivered on msg.done;
// otherwise msg.done is passed along as the frame write's done channel.
//
// NOTE(review): sc.serveG.check() presumably asserts that this runs on the
// serve goroutine — confirm.
func (sc *serverConn) startPush(msg *startPushRequest) {
	sc.serveG.check()

	// Pushing is only attempted while the parent stream is in the open or
	// half-closed (remote) state.
	if msg.parent.state != stateOpen && msg.parent.state != stateHalfClosedRemote {
		// Any other parent state is reported as a closed stream.
		msg.done <- errStreamClosed
		return
	}

	// Refuse when push is not enabled on this connection.
	if !sc.pushEnabled {
		msg.done <- http.ErrNotSupported
		return
	}

	// allocatePromisedID is stored in the writePushPromise frame below and
	// is not called here. NOTE(review): presumably it runs when the frame is
	// about to be written — confirm against writePushPromise.
	// It reserves the next promised stream ID, creates the promised stream,
	// and starts its handler.
	allocatePromisedID := func() (uint32, error) {
		sc.serveG.check()

		// Checked again because this closure runs later than the check
		// above.
		if !sc.pushEnabled {
			return 0, http.ErrNotSupported
		}
		// Do not exceed the client's stream limit.
		if sc.curPushedStreams+1 > sc.clientMaxStreams {
			return 0, ErrPushLimitReached
		}

		// The next ID would reach 1<<31: no usable IDs remain, so begin a
		// graceful shutdown of the connection and fail this push.
		if sc.maxPushPromiseID+2 >= 1<<31 {
			sc.startGracefulShutdownInternal()
			return 0, ErrPushLimitReached
		}
		sc.maxPushPromiseID += 2
		promisedID := sc.maxPushPromiseID

		// The promised stream is created in the half-closed (remote) state
		// and its request is built without a body.
		promised := sc.newStream(promisedID, msg.parent.id, stateHalfClosedRemote)
		rw, req, err := sc.newWriterAndRequestNoBody(promised, requestParam{
			method:    msg.method,
			scheme:    msg.url.Scheme,
			authority: msg.url.Host,
			path:      msg.url.RequestURI(),
			header:    cloneHeader(msg.header),
		})
		if err != nil {
			// An error here is treated as a programming error, not a
			// recoverable condition.
			panic(fmt.Sprintf("newWriterAndRequestNoBody(%+v): %v", msg.url, err))
		}

		// Account for the new handler and run it in its own goroutine.
		sc.curHandlers++
		go sc.runHandler(rw, req, sc.handler.ServeHTTP)
		return promisedID, nil
	}

	sc.writeFrame(FrameWriteRequest{
		write: &writePushPromise{
			streamID:           msg.parent.id,
			method:             msg.method,
			url:                msg.url,
			h:                  msg.header,
			allocatePromisedID: allocatePromisedID,
		},
		stream: msg.parent,
		done:   msg.done,
	})
}
3302
3303
3304
// foreachHeaderElement splits v on commas and calls fn for each element, in
// order, after trimming surrounding whitespace with textproto.TrimString.
// Empty elements are skipped, and fn is never called when v is empty after
// trimming.
func foreachHeaderElement(v string, fn func(string)) {
	v = textproto.TrimString(v)
	if v == "" {
		return
	}
	// Fast path: a value without commas is a single element.
	if strings.IndexByte(v, ',') < 0 {
		fn(v)
		return
	}
	for _, elem := range strings.Split(v, ",") {
		elem = textproto.TrimString(elem)
		if elem == "" {
			continue
		}
		fn(elem)
	}
}
3320
3321
// connHeaders lists the header names, in canonical form, whose presence
// makes checkValidHTTP2RequestHeaders reject a request header set.
var connHeaders = []string{
	"Connection",
	"Keep-Alive",
	"Proxy-Connection",
	"Transfer-Encoding",
	"Upgrade",
}
3329
3330
3331
3332
3333 func checkValidHTTP2RequestHeaders(h http.Header) error {
3334 for _, k := range connHeaders {
3335 if _, ok := h[k]; ok {
3336 return fmt.Errorf("request header %q is not valid in HTTP/2", k)
3337 }
3338 }
3339 te := h["Te"]
3340 if len(te) > 0 && (len(te) > 1 || (te[0] != "trailers" && te[0] != "")) {
3341 return errors.New(`request header "TE" may only be "trailers" in HTTP/2`)
3342 }
3343 return nil
3344 }
3345
// new400Handler returns a handler that ignores the request and replies with
// status 400 Bad Request, using err's message as the response body.
func new400Handler(err error) http.HandlerFunc {
	return http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
		http.Error(rw, err.Error(), http.StatusBadRequest)
	})
}
3351
3352
3353
3354
// h1ServerKeepAlivesDisabled reports whether hs says keep-alives are turned
// off. It does so by checking, at run time, whether hs has a doKeepAlives
// method visible to this package and, if so, returning the negation of its
// result. When no such method is visible, it returns false.
//
// NOTE(review): because doKeepAlives is unexported, the check can only
// succeed when this code is compiled into the package that defines the
// method on http.Server — confirm the intended build setup.
func h1ServerKeepAlivesDisabled(hs *http.Server) bool {
	// Go through an empty interface so the method lookup happens
	// dynamically.
	var v interface{} = hs
	ka, ok := v.(interface {
		doKeepAlives() bool
	})
	return ok && !ka.doKeepAlives()
}
3365
3366 func (sc *serverConn) countError(name string, err error) error {
3367 if sc == nil || sc.srv == nil {
3368 return err
3369 }
3370 f := sc.countErrorFunc
3371 if f == nil {
3372 return err
3373 }
3374 var typ string
3375 var code ErrCode
3376 switch e := err.(type) {
3377 case ConnectionError:
3378 typ = "conn"
3379 code = ErrCode(e)
3380 case StreamError:
3381 typ = "stream"
3382 code = ErrCode(e.Code)
3383 default:
3384 return err
3385 }
3386 codeStr := errCodeName[code]
3387 if codeStr == "" {
3388 codeStr = strconv.Itoa(int(code))
3389 }
3390 f(fmt.Sprintf("%s_%s_%s", typ, codeStr, name))
3391 return err
3392 }
3393
View as plain text