1
2
3
4
5 package http2
6
7 import (
8 "errors"
9 "fmt"
10 "io"
11 "net/http"
12 "reflect"
13 "runtime"
14 "strconv"
15 "sync"
16 "testing"
17 "time"
18 )
19
20 func TestServer_Push_Success(t *testing.T) {
21 const (
22 mainBody = "<html>index page</html>"
23 pushedBody = "<html>pushed page</html>"
24 userAgent = "testagent"
25 cookie = "testcookie"
26 )
27
28 var stURL string
29 checkPromisedReq := func(r *http.Request, wantMethod string, wantH http.Header) error {
30 if got, want := r.Method, wantMethod; got != want {
31 return fmt.Errorf("promised Req.Method=%q, want %q", got, want)
32 }
33 if got, want := r.Header, wantH; !reflect.DeepEqual(got, want) {
34 return fmt.Errorf("promised Req.Header=%q, want %q", got, want)
35 }
36 if got, want := "https://"+r.Host, stURL; got != want {
37 return fmt.Errorf("promised Req.Host=%q, want %q", got, want)
38 }
39 if r.Body == nil {
40 return fmt.Errorf("nil Body")
41 }
42 if buf, err := io.ReadAll(r.Body); err != nil || len(buf) != 0 {
43 return fmt.Errorf("ReadAll(Body)=%q,%v, want '',nil", buf, err)
44 }
45 return nil
46 }
47
48 errc := make(chan error, 3)
49 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
50 switch r.URL.RequestURI() {
51 case "/":
52
53 opt := &http.PushOptions{
54 Header: http.Header{
55 "User-Agent": {userAgent},
56 },
57 }
58 if err := w.(http.Pusher).Push(stURL+"/pushed?get", opt); err != nil {
59 errc <- fmt.Errorf("error pushing /pushed?get: %v", err)
60 return
61 }
62
63 opt = &http.PushOptions{
64 Method: "HEAD",
65 Header: http.Header{
66 "User-Agent": {userAgent},
67 "Cookie": {cookie},
68 },
69 }
70 if err := w.(http.Pusher).Push("/pushed?head", opt); err != nil {
71 errc <- fmt.Errorf("error pushing /pushed?head: %v", err)
72 return
73 }
74 w.Header().Set("Content-Type", "text/html")
75 w.Header().Set("Content-Length", strconv.Itoa(len(mainBody)))
76 w.WriteHeader(200)
77 io.WriteString(w, mainBody)
78 errc <- nil
79
80 case "/pushed?get":
81 wantH := http.Header{}
82 wantH.Set("User-Agent", userAgent)
83 if err := checkPromisedReq(r, "GET", wantH); err != nil {
84 errc <- fmt.Errorf("/pushed?get: %v", err)
85 return
86 }
87 w.Header().Set("Content-Type", "text/html")
88 w.Header().Set("Content-Length", strconv.Itoa(len(pushedBody)))
89 w.WriteHeader(200)
90 io.WriteString(w, pushedBody)
91 errc <- nil
92
93 case "/pushed?head":
94 wantH := http.Header{}
95 wantH.Set("User-Agent", userAgent)
96 wantH.Set("Cookie", cookie)
97 if err := checkPromisedReq(r, "HEAD", wantH); err != nil {
98 errc <- fmt.Errorf("/pushed?head: %v", err)
99 return
100 }
101 w.WriteHeader(204)
102 errc <- nil
103
104 default:
105 errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI())
106 }
107 })
108 stURL = "https://" + st.authority()
109
110
111 st.greet()
112 getSlash(st)
113 for k := 0; k < 3; k++ {
114 select {
115 case <-time.After(2 * time.Second):
116 t.Errorf("timeout waiting for handler %d to finish", k)
117 case err := <-errc:
118 if err != nil {
119 t.Fatal(err)
120 }
121 }
122 }
123
124 checkPushPromise := func(f Frame, promiseID uint32, wantH [][2]string) error {
125 pp, ok := f.(*PushPromiseFrame)
126 if !ok {
127 return fmt.Errorf("got a %T; want *PushPromiseFrame", f)
128 }
129 if !pp.HeadersEnded() {
130 return fmt.Errorf("want END_HEADERS flag in PushPromiseFrame")
131 }
132 if got, want := pp.PromiseID, promiseID; got != want {
133 return fmt.Errorf("got PromiseID %v; want %v", got, want)
134 }
135 gotH := st.decodeHeader(pp.HeaderBlockFragment())
136 if !reflect.DeepEqual(gotH, wantH) {
137 return fmt.Errorf("got promised headers %v; want %v", gotH, wantH)
138 }
139 return nil
140 }
141 checkHeaders := func(f Frame, wantH [][2]string) error {
142 hf, ok := f.(*HeadersFrame)
143 if !ok {
144 return fmt.Errorf("got a %T; want *HeadersFrame", f)
145 }
146 gotH := st.decodeHeader(hf.HeaderBlockFragment())
147 if !reflect.DeepEqual(gotH, wantH) {
148 return fmt.Errorf("got response headers %v; want %v", gotH, wantH)
149 }
150 return nil
151 }
152 checkData := func(f Frame, wantData string) error {
153 df, ok := f.(*DataFrame)
154 if !ok {
155 return fmt.Errorf("got a %T; want *DataFrame", f)
156 }
157 if gotData := string(df.Data()); gotData != wantData {
158 return fmt.Errorf("got response data %q; want %q", gotData, wantData)
159 }
160 return nil
161 }
162
163
164
165
166 expected := map[uint32][]func(Frame) error{
167 1: {
168 func(f Frame) error {
169 return checkPushPromise(f, 2, [][2]string{
170 {":method", "GET"},
171 {":scheme", "https"},
172 {":authority", st.authority()},
173 {":path", "/pushed?get"},
174 {"user-agent", userAgent},
175 })
176 },
177 func(f Frame) error {
178 return checkPushPromise(f, 4, [][2]string{
179 {":method", "HEAD"},
180 {":scheme", "https"},
181 {":authority", st.authority()},
182 {":path", "/pushed?head"},
183 {"cookie", cookie},
184 {"user-agent", userAgent},
185 })
186 },
187 func(f Frame) error {
188 return checkHeaders(f, [][2]string{
189 {":status", "200"},
190 {"content-type", "text/html"},
191 {"content-length", strconv.Itoa(len(mainBody))},
192 })
193 },
194 func(f Frame) error {
195 return checkData(f, mainBody)
196 },
197 },
198 2: {
199 func(f Frame) error {
200 return checkHeaders(f, [][2]string{
201 {":status", "200"},
202 {"content-type", "text/html"},
203 {"content-length", strconv.Itoa(len(pushedBody))},
204 })
205 },
206 func(f Frame) error {
207 return checkData(f, pushedBody)
208 },
209 },
210 4: {
211 func(f Frame) error {
212 return checkHeaders(f, [][2]string{
213 {":status", "204"},
214 })
215 },
216 },
217 }
218
219 consumed := map[uint32]int{}
220 for k := 0; len(expected) > 0; k++ {
221 f := st.readFrame()
222 if f == nil {
223 for id, left := range expected {
224 t.Errorf("stream %d: missing %d frames", id, len(left))
225 }
226 break
227 }
228 id := f.Header().StreamID
229 label := fmt.Sprintf("stream %d, frame %d", id, consumed[id])
230 if len(expected[id]) == 0 {
231 t.Fatalf("%s: unexpected frame %#+v", label, f)
232 }
233 check := expected[id][0]
234 expected[id] = expected[id][1:]
235 if len(expected[id]) == 0 {
236 delete(expected, id)
237 }
238 if err := check(f); err != nil {
239 t.Fatalf("%s: %v", label, err)
240 }
241 consumed[id]++
242 }
243 }
244
245 func TestServer_Push_SuccessNoRace(t *testing.T) {
246
247
248 errc := make(chan error, 2)
249 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
250 switch r.URL.RequestURI() {
251 case "/":
252 opt := &http.PushOptions{
253 Header: http.Header{"User-Agent": {"testagent"}},
254 }
255 if err := w.(http.Pusher).Push("/pushed", opt); err != nil {
256 errc <- fmt.Errorf("error pushing: %v", err)
257 return
258 }
259 w.WriteHeader(200)
260 errc <- nil
261
262 case "/pushed":
263
264 r.Header.Set("User-Agent", "newagent")
265 r.Header.Set("Cookie", "cookie")
266 w.WriteHeader(200)
267 errc <- nil
268
269 default:
270 errc <- fmt.Errorf("unknown RequestURL %q", r.URL.RequestURI())
271 }
272 })
273
274
275 st.greet()
276 getSlash(st)
277 for k := 0; k < 2; k++ {
278 select {
279 case <-time.After(2 * time.Second):
280 t.Errorf("timeout waiting for handler %d to finish", k)
281 case err := <-errc:
282 if err != nil {
283 t.Fatal(err)
284 }
285 }
286 }
287 }
288
289 func TestServer_Push_RejectRecursivePush(t *testing.T) {
290
291 errc := make(chan error, 3)
292 handler := func(w http.ResponseWriter, r *http.Request) error {
293 baseURL := "https://" + r.Host
294 switch r.URL.Path {
295 case "/":
296 if err := w.(http.Pusher).Push(baseURL+"/push1", nil); err != nil {
297 return fmt.Errorf("first Push()=%v, want nil", err)
298 }
299 return nil
300
301 case "/push1":
302 if got, want := w.(http.Pusher).Push(baseURL+"/push2", nil), ErrRecursivePush; got != want {
303 return fmt.Errorf("Push()=%v, want %v", got, want)
304 }
305 return nil
306
307 default:
308 return fmt.Errorf("unexpected path: %q", r.URL.Path)
309 }
310 }
311 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
312 errc <- handler(w, r)
313 })
314 defer st.Close()
315 st.greet()
316 getSlash(st)
317 if err := <-errc; err != nil {
318 t.Errorf("First request failed: %v", err)
319 }
320 if err := <-errc; err != nil {
321 t.Errorf("Second request failed: %v", err)
322 }
323 }
324
325 func testServer_Push_RejectSingleRequest(t *testing.T, doPush func(http.Pusher, *http.Request) error, settings ...Setting) {
326
327 errc := make(chan error, 2)
328 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
329 errc <- doPush(w.(http.Pusher), r)
330 })
331 defer st.Close()
332 st.greet()
333 if err := st.fr.WriteSettings(settings...); err != nil {
334 st.t.Fatalf("WriteSettings: %v", err)
335 }
336 st.wantSettingsAck()
337 getSlash(st)
338 if err := <-errc; err != nil {
339 t.Error(err)
340 }
341
342 st.wantHeaders(wantHeader{
343 streamID: 1,
344 endStream: true,
345 })
346 }
347
348 func TestServer_Push_RejectIfDisabled(t *testing.T) {
349 testServer_Push_RejectSingleRequest(t,
350 func(p http.Pusher, r *http.Request) error {
351 if got, want := p.Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want {
352 return fmt.Errorf("Push()=%v, want %v", got, want)
353 }
354 return nil
355 },
356 Setting{SettingEnablePush, 0})
357 }
358
359 func TestServer_Push_RejectWhenNoConcurrentStreams(t *testing.T) {
360 testServer_Push_RejectSingleRequest(t,
361 func(p http.Pusher, r *http.Request) error {
362 if got, want := p.Push("https://"+r.Host+"/pushed", nil), ErrPushLimitReached; got != want {
363 return fmt.Errorf("Push()=%v, want %v", got, want)
364 }
365 return nil
366 },
367 Setting{SettingMaxConcurrentStreams, 0})
368 }
369
370 func TestServer_Push_RejectWrongScheme(t *testing.T) {
371 testServer_Push_RejectSingleRequest(t,
372 func(p http.Pusher, r *http.Request) error {
373 if err := p.Push("http://"+r.Host+"/pushed", nil); err == nil {
374 return errors.New("Push() should have failed (push target URL is http)")
375 }
376 return nil
377 })
378 }
379
380 func TestServer_Push_RejectMissingHost(t *testing.T) {
381 testServer_Push_RejectSingleRequest(t,
382 func(p http.Pusher, r *http.Request) error {
383 if err := p.Push("https:pushed", nil); err == nil {
384 return errors.New("Push() should have failed (push target URL missing host)")
385 }
386 return nil
387 })
388 }
389
390 func TestServer_Push_RejectRelativePath(t *testing.T) {
391 testServer_Push_RejectSingleRequest(t,
392 func(p http.Pusher, r *http.Request) error {
393 if err := p.Push("../test", nil); err == nil {
394 return errors.New("Push() should have failed (push target is a relative path)")
395 }
396 return nil
397 })
398 }
399
400 func TestServer_Push_RejectForbiddenMethod(t *testing.T) {
401 testServer_Push_RejectSingleRequest(t,
402 func(p http.Pusher, r *http.Request) error {
403 if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Method: "POST"}); err == nil {
404 return errors.New("Push() should have failed (cannot promise a POST)")
405 }
406 return nil
407 })
408 }
409
410 func TestServer_Push_RejectForbiddenHeader(t *testing.T) {
411 testServer_Push_RejectSingleRequest(t,
412 func(p http.Pusher, r *http.Request) error {
413 header := http.Header{
414 "Content-Length": {"10"},
415 "Content-Encoding": {"gzip"},
416 "Trailer": {"Foo"},
417 "Te": {"trailers"},
418 "Host": {"test.com"},
419 ":authority": {"test.com"},
420 }
421 if err := p.Push("https://"+r.Host+"/pushed", &http.PushOptions{Header: header}); err == nil {
422 return errors.New("Push() should have failed (forbidden headers)")
423 }
424 return nil
425 })
426 }
427
428 func TestServer_Push_StateTransitions(t *testing.T) {
429 const body = "foo"
430
431 gotPromise := make(chan bool)
432 finishedPush := make(chan bool)
433
434 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
435 switch r.URL.RequestURI() {
436 case "/":
437 if err := w.(http.Pusher).Push("/pushed", nil); err != nil {
438 t.Errorf("Push error: %v", err)
439 }
440
441
442 <-finishedPush
443 case "/pushed":
444 <-gotPromise
445 }
446 w.Header().Set("Content-Type", "text/html")
447 w.Header().Set("Content-Length", strconv.Itoa(len(body)))
448 w.WriteHeader(200)
449 io.WriteString(w, body)
450 })
451 defer st.Close()
452
453 st.greet()
454 if st.stream(2) != nil {
455 t.Fatal("stream 2 should be empty")
456 }
457 if got, want := st.streamState(2), stateIdle; got != want {
458 t.Fatalf("streamState(2)=%v, want %v", got, want)
459 }
460 getSlash(st)
461
462 _ = readFrame[*PushPromiseFrame](t, st)
463 if got, want := st.streamState(2), stateHalfClosedRemote; got != want {
464 t.Fatalf("streamState(2)=%v, want %v", got, want)
465 }
466
467
468
469
470 close(gotPromise)
471 st.wantHeaders(wantHeader{
472 streamID: 2,
473 endStream: false,
474 })
475 if got, want := st.streamState(2), stateClosed; got != want {
476 t.Fatalf("streamState(2)=%v, want %v", got, want)
477 }
478 close(finishedPush)
479 }
480
481 func TestServer_Push_RejectAfterGoAway(t *testing.T) {
482 var readyOnce sync.Once
483 ready := make(chan struct{})
484 errc := make(chan error, 2)
485 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
486 <-ready
487 if got, want := w.(http.Pusher).Push("https://"+r.Host+"/pushed", nil), http.ErrNotSupported; got != want {
488 errc <- fmt.Errorf("Push()=%v, want %v", got, want)
489 }
490 errc <- nil
491 })
492 defer st.Close()
493 st.greet()
494 getSlash(st)
495
496
497 st.fr.WriteGoAway(1, ErrCodeNo, nil)
498 go func() {
499 for {
500 select {
501 case <-ready:
502 return
503 default:
504 if runtime.GOARCH == "wasm" {
505
506 runtime.Gosched()
507 }
508 }
509 st.sc.serveMsgCh <- func(loopNum int) {
510 if !st.sc.pushEnabled {
511 readyOnce.Do(func() { close(ready) })
512 }
513 }
514 }
515 }()
516 if err := <-errc; err != nil {
517 t.Error(err)
518 }
519 }
520
521 func TestServer_Push_Underflow(t *testing.T) {
522
523
524 st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
525 switch r.URL.RequestURI() {
526 case "/":
527 opt := &http.PushOptions{
528 Header: http.Header{"User-Agent": {"testagent"}},
529 }
530 if err := w.(http.Pusher).Push("/pushed", opt); err != nil {
531 t.Errorf("error pushing: %v", err)
532 }
533 w.WriteHeader(200)
534 case "/pushed":
535 r.Header.Set("User-Agent", "newagent")
536 r.Header.Set("Cookie", "cookie")
537 w.WriteHeader(200)
538 default:
539 t.Errorf("unknown RequestURL %q", r.URL.RequestURI())
540 }
541 })
542
543 st.greet()
544 const numRequests = 4
545 for i := 0; i < numRequests; i++ {
546 st.writeHeaders(HeadersFrameParam{
547 StreamID: uint32(1 + i*2),
548 BlockFragment: st.encodeHeader(),
549 EndStream: true,
550 EndHeaders: true,
551 })
552 }
553
554 numPushPromises := 0
555 numHeaders := 0
556 for numHeaders < numRequests*2 || numPushPromises < numRequests {
557 f := st.readFrame()
558 if f == nil {
559 st.t.Fatal("conn is idle, want frame")
560 }
561 switch f := f.(type) {
562 case *HeadersFrame:
563 if !f.Flags.Has(FlagHeadersEndStream) {
564 t.Fatalf("got HEADERS frame with no END_STREAM, expected END_STREAM: %v", f)
565 }
566 numHeaders++
567 case *PushPromiseFrame:
568 numPushPromises++
569 }
570 }
571 }
572
View as plain text