1 package httpsnoop
2
3 import (
4 "bytes"
5 "io"
6 "io/ioutil"
7 "net/http"
8 "net/http/httptest"
9 "testing"
10 )
11
12 func TestWrap_integration(t *testing.T) {
13 tests := []struct {
14 Name string
15 Handler http.Handler
16 Hooks Hooks
17 WantCode int
18 WantBody []byte
19 }{
20 {
21 Name: "WriteHeader (no hook)",
22 Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
23 w.WriteHeader(http.StatusNotFound)
24 }),
25 WantCode: http.StatusNotFound,
26 },
27 {
28 Name: "WriteHeader",
29 Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
30 w.WriteHeader(http.StatusNotFound)
31 }),
32 Hooks: Hooks{
33 WriteHeader: func(next WriteHeaderFunc) WriteHeaderFunc {
34 return func(code int) {
35 if code != http.StatusNotFound {
36 t.Errorf("got=%d want=%d", code, http.StatusNotFound)
37 }
38 next(http.StatusForbidden)
39 }
40 },
41 },
42 WantCode: http.StatusForbidden,
43 },
44
45 {
46 Name: "Write (no hook)",
47 Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
48 w.Write([]byte("foo"))
49 }),
50 WantCode: http.StatusOK,
51 WantBody: []byte("foo"),
52 },
53 {
54 Name: "Write (rewrite hook)",
55 Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
56 if n, err := w.Write([]byte("foo")); err != nil {
57 t.Errorf("got=%s", err)
58 } else if got, want := n, len("foobar"); got != want {
59 t.Errorf("got=%d want=%d", got, want)
60 }
61 }),
62 Hooks: Hooks{
63 Write: func(next WriteFunc) WriteFunc {
64 return func(p []byte) (int, error) {
65 if string(p) != "foo" {
66 t.Errorf("%s", p)
67 }
68 return next([]byte("foobar"))
69 }
70 },
71 },
72 WantCode: http.StatusOK,
73 WantBody: []byte("foobar"),
74 },
75 {
76 Name: "Write (error hook)",
77 Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
78 if n, err := w.Write([]byte("foo")); n != 0 {
79 t.Errorf("got=%d want=%d", n, 0)
80 } else if err != io.EOF {
81 t.Errorf("got=%s want=%s", err, io.EOF)
82 }
83 }),
84 Hooks: Hooks{
85 Write: func(next WriteFunc) WriteFunc {
86 return func(p []byte) (int, error) {
87 if string(p) != "foo" {
88 t.Errorf("%s", p)
89 }
90 return 0, io.EOF
91 }
92 },
93 },
94 WantCode: http.StatusOK,
95 },
96 }
97
98 for _, test := range tests {
99 func() {
100 h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
101 sw := Wrap(w, test.Hooks)
102 test.Handler.ServeHTTP(sw, r)
103 })
104 s := httptest.NewServer(h)
105 defer s.Close()
106 res, err := http.Get(s.URL)
107 if err != nil {
108 t.Fatal(err)
109 }
110 defer res.Body.Close()
111 gotBody, err := ioutil.ReadAll(res.Body)
112 if res.StatusCode != test.WantCode {
113 t.Errorf("got=%d want=%d", res.StatusCode, test.WantCode)
114 } else if !bytes.Equal(gotBody, test.WantBody) {
115 t.Errorf("got=%s want=%s", gotBody, test.WantBody)
116 }
117 }()
118 }
119 }
120
View as plain text