1
2
3
4
5 package handlers
6
7 import (
8 "net/http"
9 "net/http/httptest"
10 "net/url"
11 "strings"
12 "testing"
13 )
14
// Expected response bodies shared by the handler tests in this file.
const (
	// ok is the body written by okHandler on every request.
	ok = "ok\n"

	// notAllowed is the body the tests expect with a 405 response.
	// NOTE(review): the trailing newline presumably comes from http.Error,
	// which appends one; confirm against the MethodHandler implementation.
	notAllowed = "Method not allowed\n"
)
19
20 var okHandler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
21 w.Write([]byte(ok))
22 })
23
// newRequest builds a body-less request with the given method and target
// URL for use in table-driven test fixtures.
//
// It panics with the error from http.NewRequest if the method or URL is
// rejected; a malformed fixture is a programming error in the test itself.
// The parameter is named target (not url) so it does not shadow the
// imported net/url package.
func newRequest(method, target string) *http.Request {
	req, err := http.NewRequest(method, target, nil)
	if err != nil {
		panic(err)
	}
	return req
}
31
// TestMethodHandler verifies MethodHandler's dispatch-by-method behavior:
//   - a request whose method has a registered handler is served by it, and
//     no Allow header is expected;
//   - a request whose method has no handler gets 405 with the notAllowed
//     body and an Allow header listing the registered methods (empty when
//     none are registered);
//   - an OPTIONS request without a registered OPTIONS handler gets 200, an
//     empty body and the same Allow header;
//   - a registered OPTIONS handler is used instead of that default.
func TestMethodHandler(t *testing.T) {
	tests := []struct {
		req     *http.Request
		handler http.Handler
		code    int
		allow   string
		body    string
	}{
		// No handlers registered.
		{newRequest("GET", "/foo"), MethodHandler{}, http.StatusMethodNotAllowed, "", notAllowed},
		{newRequest("OPTIONS", "/foo"), MethodHandler{}, http.StatusOK, "", ""},

		// A single registered handler.
		{newRequest("GET", "/foo"), MethodHandler{"GET": okHandler}, http.StatusOK, "", ok},
		{newRequest("POST", "/foo"), MethodHandler{"GET": okHandler}, http.StatusMethodNotAllowed, "GET", notAllowed},

		// Multiple registered handlers.
		{newRequest("GET", "/foo"), MethodHandler{"GET": okHandler, "POST": okHandler}, http.StatusOK, "", ok},
		{newRequest("POST", "/foo"), MethodHandler{"GET": okHandler, "POST": okHandler}, http.StatusOK, "", ok},
		{newRequest("DELETE", "/foo"), MethodHandler{"GET": okHandler, "POST": okHandler}, http.StatusMethodNotAllowed, "GET, POST", notAllowed},
		{newRequest("OPTIONS", "/foo"), MethodHandler{"GET": okHandler, "POST": okHandler}, http.StatusOK, "GET, POST", ""},

		// An explicit OPTIONS handler overrides the default OPTIONS reply.
		{newRequest("OPTIONS", "/foo"), MethodHandler{"OPTIONS": okHandler}, http.StatusOK, "", ok},
	}

	for i, test := range tests {
		rec := httptest.NewRecorder()
		test.handler.ServeHTTP(rec, test.req)

		// Use Errorf (not Fatalf) so every mismatching case is reported in a
		// single run; each message identifies the case by index and method.
		if rec.Code != test.code {
			t.Errorf("%d (%s): wrong code, got %d want %d", i, test.req.Method, rec.Code, test.code)
		}
		// ResponseRecorder.HeaderMap is deprecated; Result().Header is the
		// documented way to read the headers the handler produced.
		if allow := rec.Result().Header.Get("Allow"); allow != test.allow {
			t.Errorf("%d (%s): wrong Allow, got %q want %q", i, test.req.Method, allow, test.allow)
		}
		if body := rec.Body.String(); body != test.body {
			t.Errorf("%d (%s): wrong body, got %q want %q", i, test.req.Method, body, test.body)
		}
	}
}
72
// TestContentTypeHandler verifies that ContentTypeHandler accepts a request
// whose Content-Type media type is in the allowed list (media-type parameters
// such as charset do not prevent a match, but a different type such as
// "application/json+xxx" does), rejects other types with 415, and lets GET
// requests with an empty Content-Type through regardless of the list.
func TestContentTypeHandler(t *testing.T) {
	tests := []struct {
		Method            string
		AllowContentTypes []string
		ContentType       string
		Code              int
	}{
		{"POST", []string{"application/json"}, "application/json", http.StatusOK},
		{"POST", []string{"application/json", "application/xml"}, "application/json", http.StatusOK},
		{"POST", []string{"application/json"}, "application/json; charset=utf-8", http.StatusOK},
		{"POST", []string{"application/json"}, "application/json+xxx", http.StatusUnsupportedMediaType},
		{"POST", []string{"application/json"}, "text/plain", http.StatusUnsupportedMediaType},
		{"GET", []string{"application/json"}, "", http.StatusOK},
		{"GET", []string{}, "", http.StatusOK},
	}
	for i, test := range tests {
		r, err := http.NewRequest(test.Method, "/", nil)
		if err != nil {
			t.Errorf("%d: building request: %v", i, err)
			continue
		}

		h := ContentTypeHandler(okHandler, test.AllowContentTypes...)
		// Note: for the GET cases this sets an explicitly empty Content-Type.
		r.Header.Set("Content-Type", test.ContentType)
		w := httptest.NewRecorder()
		h.ServeHTTP(w, r)
		if w.Code != test.Code {
			// Include the full case so a failure is identifiable without
			// counting table rows.
			t.Errorf("%d: %s with Content-Type %q (allowed %q): expected %d, got %d",
				i, test.Method, test.ContentType, test.AllowContentTypes, test.Code, w.Code)
		}
	}
}
104
// TestHTTPMethodOverride verifies that HTTPMethodOverrideHandler replaces the
// method of a POST request with the one supplied through either the override
// header or the override form field. Requests with the other methods in the
// table (PUT, GET, HEAD) keep their original method.
//
// The check inspects the request after ServeHTTP returns, so it relies on
// the handler updating the *http.Request in place.
func TestHTTPMethodOverride(t *testing.T) {
	var tests = []struct {
		Method         string
		OverrideMethod string
		ExpectedMethod string
	}{
		{"POST", "PUT", "PUT"},
		{"POST", "PATCH", "PATCH"},
		{"POST", "DELETE", "DELETE"},
		{"PUT", "DELETE", "PUT"},
		{"GET", "GET", "GET"},
		{"HEAD", "HEAD", "HEAD"},
		{"GET", "PUT", "GET"},
		{"HEAD", "DELETE", "HEAD"},
	}

	for _, test := range tests {
		h := HTTPMethodOverrideHandler(okHandler)

		// Fatal rather than Error: continuing with a nil request would panic
		// on the Header.Set call that follows.
		rHeader, err := http.NewRequest(test.Method, "/", nil)
		if err != nil {
			t.Fatal(err)
		}
		rHeader.Header.Set(HTTPMethodOverrideHeader, test.OverrideMethod)

		f := url.Values{HTTPMethodOverrideFormKey: []string{test.OverrideMethod}}
		rForm, err := http.NewRequest(test.Method, "/", strings.NewReader(f.Encode()))
		if err != nil {
			t.Fatal(err)
		}
		rForm.Header.Set("Content-Type", "application/x-www-form-urlencoded")

		// Exercise both override channels; via labels failures.
		reqs := []struct {
			via string
			r   *http.Request
		}{
			{"header", rHeader},
			{"form", rForm},
		}
		for _, req := range reqs {
			h.ServeHTTP(httptest.NewRecorder(), req.r)
			if req.r.Method != test.ExpectedMethod {
				t.Errorf("%s with override %s via %s: expected %s, got %s",
					test.Method, test.OverrideMethod, req.via, test.ExpectedMethod, req.r.Method)
			}
		}
	}
}
149
View as plain text