1
2
3
4
5 package h2c
6
7 import (
8 "context"
9 "crypto/tls"
10 "fmt"
11 "io"
12 "log"
13 "net"
14 "net/http"
15 "net/http/httptest"
16 "strings"
17 "testing"
18
19 "golang.org/x/net/http2"
20 )
21
22 func ExampleNewHandler() {
23 handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
24 fmt.Fprint(w, "Hello world")
25 })
26 h2s := &http2.Server{
27
28 }
29 h1s := &http.Server{
30 Addr: ":8080",
31 Handler: NewHandler(handler, h2s),
32 }
33 log.Fatal(h1s.ListenAndServe())
34 }
35
36 func TestContext(t *testing.T) {
37 baseCtx := context.WithValue(context.Background(), "testkey", "testvalue")
38
39 handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
40 if r.ProtoMajor != 2 {
41 t.Errorf("Request wasn't handled by h2c. Got ProtoMajor=%v", r.ProtoMajor)
42 }
43 if r.Context().Value("testkey") != "testvalue" {
44 t.Errorf("Request doesn't have expected base context: %v", r.Context())
45 }
46 fmt.Fprint(w, "Hello world")
47 })
48
49 h2s := &http2.Server{}
50 h1s := httptest.NewUnstartedServer(NewHandler(handler, h2s))
51 h1s.Config.BaseContext = func(_ net.Listener) context.Context {
52 return baseCtx
53 }
54 h1s.Start()
55 defer h1s.Close()
56
57 client := &http.Client{
58 Transport: &http2.Transport{
59 AllowHTTP: true,
60 DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) {
61 return net.Dial(network, addr)
62 },
63 },
64 }
65
66 resp, err := client.Get(h1s.URL)
67 if err != nil {
68 t.Fatal(err)
69 }
70 _, err = io.ReadAll(resp.Body)
71 if err != nil {
72 t.Fatal(err)
73 }
74 if err := resp.Body.Close(); err != nil {
75 t.Fatal(err)
76 }
77 }
78
79 func TestPropagation(t *testing.T) {
80 var (
81 server *http.Server
82
83 headerSize = 1 << 11
84 headerLimit = 1 << 10
85 )
86
87 handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
88 if r.ProtoMajor != 2 {
89 t.Errorf("Request wasn't handled by h2c. Got ProtoMajor=%v", r.ProtoMajor)
90 }
91 if r.Context().Value(http.ServerContextKey).(*http.Server) != server {
92 t.Errorf("Request doesn't have expected http server: %v", r.Context())
93 }
94 if len(r.Header.Get("Long-Header")) != headerSize {
95 t.Errorf("Request doesn't have expected http header length: %v", len(r.Header.Get("Long-Header")))
96 }
97 fmt.Fprint(w, "Hello world")
98 })
99
100 h2s := &http2.Server{}
101 h1s := httptest.NewUnstartedServer(NewHandler(handler, h2s))
102
103 server = h1s.Config
104 server.MaxHeaderBytes = headerLimit
105 server.ConnState = func(conn net.Conn, state http.ConnState) {
106 t.Logf("server conn state: conn %s -> %s, status changed to %s", conn.RemoteAddr(), conn.LocalAddr(), state)
107 }
108
109 h1s.Start()
110 defer h1s.Close()
111
112 client := &http.Client{
113 Transport: &http2.Transport{
114 AllowHTTP: true,
115 DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) {
116 conn, err := net.Dial(network, addr)
117 if conn != nil {
118 t.Logf("client dial tls: %s -> %s", conn.RemoteAddr(), conn.LocalAddr())
119 }
120 return conn, err
121 },
122 },
123 }
124
125 req, err := http.NewRequest("GET", h1s.URL, nil)
126 if err != nil {
127 t.Fatal(err)
128 }
129
130 req.Header.Set("Long-Header", strings.Repeat("A", headerSize))
131
132 _, err = client.Do(req)
133 if err == nil {
134 t.Fatal("expected server err, got nil")
135 }
136 }
137
138 func TestMaxBytesHandler(t *testing.T) {
139 const bodyLimit = 10
140 handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
141 t.Errorf("got request, expected to be blocked by body limit")
142 })
143
144 h2s := &http2.Server{}
145 h1s := httptest.NewUnstartedServer(http.MaxBytesHandler(NewHandler(handler, h2s), bodyLimit))
146 h1s.Start()
147 defer h1s.Close()
148
149
150 body := "0123456789abcdef"
151 req, err := http.NewRequest("POST", h1s.URL, struct{ io.Reader }{strings.NewReader(body)})
152 if err != nil {
153 t.Fatal(err)
154 }
155 req.Header.Set("Http2-Settings", "")
156 req.Header.Set("Upgrade", "h2c")
157 req.Header.Set("Connection", "Upgrade, HTTP2-Settings")
158
159 resp, err := h1s.Client().Do(req)
160 if err != nil {
161 t.Fatal(err)
162 }
163 defer resp.Body.Close()
164 _, err = io.ReadAll(resp.Body)
165 if err != nil {
166 t.Fatal(err)
167 }
168 if got, want := resp.StatusCode, http.StatusInternalServerError; got != want {
169 t.Errorf("resp.StatusCode = %v, want %v", got, want)
170 }
171 }
172
View as plain text