1
16
17
18 package test
19
20 import (
21 "bytes"
22 "errors"
23 "io"
24 "strings"
25 "testing"
26 "time"
27
28 "golang.org/x/net/http2"
29 "golang.org/x/net/http2/hpack"
30 )
31
32
33
34
35
36
37
38
39
40
41
42 type serverTester struct {
43 cc io.ReadWriteCloser
44 t testing.TB
45 fr *http2.Framer
46
47
48 headerBuf bytes.Buffer
49 hpackEnc *hpack.Encoder
50
51
52 frc chan http2.Frame
53 frErrc chan error
54 }
55
56 func newServerTesterFromConn(t testing.TB, cc io.ReadWriteCloser) *serverTester {
57 st := &serverTester{
58 t: t,
59 cc: cc,
60 frc: make(chan http2.Frame, 1),
61 frErrc: make(chan error, 1),
62 }
63 st.hpackEnc = hpack.NewEncoder(&st.headerBuf)
64 st.fr = http2.NewFramer(cc, cc)
65 st.fr.ReadMetaHeaders = hpack.NewDecoder(4096 , nil)
66
67 return st
68 }
69
70 func (st *serverTester) readFrame() (http2.Frame, error) {
71 go func() {
72 fr, err := st.fr.ReadFrame()
73 if err != nil {
74 st.frErrc <- err
75 } else {
76 st.frc <- fr
77 }
78 }()
79 t := time.NewTimer(2 * time.Second)
80 defer t.Stop()
81 select {
82 case f := <-st.frc:
83 return f, nil
84 case err := <-st.frErrc:
85 return nil, err
86 case <-t.C:
87 return nil, errors.New("timeout waiting for frame")
88 }
89 }
90
91
92
93 func (st *serverTester) greet() {
94 st.writePreface()
95 st.writeInitialSettings()
96 st.wantSettings()
97 st.writeSettingsAck()
98 for {
99 f, err := st.readFrame()
100 if err != nil {
101 st.t.Fatal(err)
102 }
103 switch f := f.(type) {
104 case *http2.WindowUpdateFrame:
105
106
107
108 case *http2.SettingsFrame:
109 if f.IsAck() {
110 return
111 }
112 st.t.Fatalf("during greet, got non-ACK settings frame")
113 default:
114 st.t.Fatalf("during greet, unexpected frame type %T", f)
115 }
116 }
117 }
118
119 func (st *serverTester) writePreface() {
120 n, err := st.cc.Write([]byte(http2.ClientPreface))
121 if err != nil {
122 st.t.Fatalf("Error writing client preface: %v", err)
123 }
124 if n != len(http2.ClientPreface) {
125 st.t.Fatalf("Writing client preface, wrote %d bytes; want %d", n, len(http2.ClientPreface))
126 }
127 }
128
129 func (st *serverTester) writeInitialSettings() {
130 if err := st.fr.WriteSettings(); err != nil {
131 st.t.Fatalf("Error writing initial SETTINGS frame from client to server: %v", err)
132 }
133 }
134
135 func (st *serverTester) writeSettingsAck() {
136 if err := st.fr.WriteSettingsAck(); err != nil {
137 st.t.Fatalf("Error writing ACK of server's SETTINGS: %v", err)
138 }
139 }
140
141 func (st *serverTester) wantGoAway(errCode http2.ErrCode) *http2.GoAwayFrame {
142 f, err := st.readFrame()
143 if err != nil {
144 st.t.Fatalf("Error while expecting an RST frame: %v", err)
145 }
146 gaf, ok := f.(*http2.GoAwayFrame)
147 if !ok {
148 st.t.Fatalf("got a %T; want *http2.GoAwayFrame", f)
149 }
150 if gaf.ErrCode != errCode {
151 st.t.Fatalf("expected GOAWAY error code '%v', got '%v'", errCode.String(), gaf.ErrCode.String())
152 }
153 return gaf
154 }
155
156 func (st *serverTester) wantPing() *http2.PingFrame {
157 f, err := st.readFrame()
158 if err != nil {
159 st.t.Fatalf("Error while expecting an RST frame: %v", err)
160 }
161 pf, ok := f.(*http2.PingFrame)
162 if !ok {
163 st.t.Fatalf("got a %T; want *http2.GoAwayFrame", f)
164 }
165 return pf
166 }
167
168 func (st *serverTester) wantRSTStream(errCode http2.ErrCode) *http2.RSTStreamFrame {
169 f, err := st.readFrame()
170 if err != nil {
171 st.t.Fatalf("Error while expecting an RST frame: %v", err)
172 }
173 rf, ok := f.(*http2.RSTStreamFrame)
174 if !ok {
175 st.t.Fatalf("got a %T; want *http2.RSTStreamFrame", f)
176 }
177 if rf.ErrCode != errCode {
178 st.t.Fatalf("expected RST error code '%v', got '%v'", errCode.String(), rf.ErrCode.String())
179 }
180 return rf
181 }
182
183 func (st *serverTester) wantSettings() *http2.SettingsFrame {
184 f, err := st.readFrame()
185 if err != nil {
186 st.t.Fatalf("Error while expecting a SETTINGS frame: %v", err)
187 }
188 sf, ok := f.(*http2.SettingsFrame)
189 if !ok {
190 st.t.Fatalf("got a %T; want *SettingsFrame", f)
191 }
192 return sf
193 }
194
195
196 func (st *serverTester) wantAnyFrame() http2.Frame {
197 f, err := st.fr.ReadFrame()
198 if err != nil {
199 st.t.Fatal(err)
200 }
201 return f
202 }
203
204 func (st *serverTester) encodeHeaderField(k, v string) {
205 err := st.hpackEnc.WriteField(hpack.HeaderField{Name: k, Value: v})
206 if err != nil {
207 st.t.Fatalf("HPACK encoding error for %q/%q: %v", k, v, err)
208 }
209 }
210
211
212
213
214
215 func (st *serverTester) encodeHeader(headers ...string) []byte {
216 if len(headers)%2 == 1 {
217 panic("odd number of kv args")
218 }
219
220 st.headerBuf.Reset()
221
222 if len(headers) == 0 {
223
224
225 st.encodeHeaderField(":method", "GET")
226 st.encodeHeaderField(":path", "/")
227 st.encodeHeaderField(":scheme", "https")
228 return st.headerBuf.Bytes()
229 }
230
231 if len(headers) == 2 && headers[0] == ":method" {
232
233 st.encodeHeaderField(":method", headers[1])
234 st.encodeHeaderField(":path", "/")
235 st.encodeHeaderField(":scheme", "https")
236 return st.headerBuf.Bytes()
237 }
238
239 pseudoCount := map[string]int{}
240 keys := []string{":method", ":path", ":scheme"}
241 vals := map[string][]string{
242 ":method": {"GET"},
243 ":path": {"/"},
244 ":scheme": {"https"},
245 }
246 for len(headers) > 0 {
247 k, v := headers[0], headers[1]
248 headers = headers[2:]
249 if _, ok := vals[k]; !ok {
250 keys = append(keys, k)
251 }
252 if strings.HasPrefix(k, ":") {
253 pseudoCount[k]++
254 if pseudoCount[k] == 1 {
255 vals[k] = []string{v}
256 } else {
257
258 vals[k] = append(vals[k], v)
259 }
260 } else {
261 vals[k] = append(vals[k], v)
262 }
263 }
264 for _, k := range keys {
265 for _, v := range vals[k] {
266 st.encodeHeaderField(k, v)
267 }
268 }
269 return st.headerBuf.Bytes()
270 }
271
272 func (st *serverTester) writeHeadersGRPC(streamID uint32, path string, endStream bool) {
273 st.writeHeaders(http2.HeadersFrameParam{
274 StreamID: streamID,
275 BlockFragment: st.encodeHeader(
276 ":method", "POST",
277 ":path", path,
278 "content-type", "application/grpc",
279 "te", "trailers",
280 ),
281 EndStream: endStream,
282 EndHeaders: true,
283 })
284 }
285
286 func (st *serverTester) writeHeaders(p http2.HeadersFrameParam) {
287 if err := st.fr.WriteHeaders(p); err != nil {
288 st.t.Fatalf("Error writing HEADERS: %v", err)
289 }
290 }
291
292 func (st *serverTester) writeData(streamID uint32, endStream bool, data []byte) {
293 if err := st.fr.WriteData(streamID, endStream, data); err != nil {
294 st.t.Fatalf("Error writing DATA: %v", err)
295 }
296 }
297
298 func (st *serverTester) writeRSTStream(streamID uint32, code http2.ErrCode) {
299 if err := st.fr.WriteRSTStream(streamID, code); err != nil {
300 st.t.Fatalf("Error writing RST_STREAM: %v", err)
301 }
302 }
303
304 func (st *serverTester) writePing(ack bool, data [8]byte) {
305 if err := st.fr.WritePing(ack, data); err != nil {
306 st.t.Fatalf("Error writing PING: %v", err)
307 }
308 }
309
View as plain text