1 package middleware
2
3 import (
4 "net/http"
5 "net/http/httptest"
6 "testing"
7
8 "github.com/go-chi/chi"
9 )
10
11 func TestStripSlashes(t *testing.T) {
12 r := chi.NewRouter()
13
14
15
16 r.Use(StripSlashes)
17
18 r.NotFound(func(w http.ResponseWriter, r *http.Request) {
19 w.WriteHeader(404)
20 w.Write([]byte("nothing here"))
21 })
22
23 r.Get("/", func(w http.ResponseWriter, r *http.Request) {
24 w.Write([]byte("root"))
25 })
26
27 r.Route("/accounts/{accountID}", func(r chi.Router) {
28 r.Get("/", func(w http.ResponseWriter, r *http.Request) {
29 accountID := chi.URLParam(r, "accountID")
30 w.Write([]byte(accountID))
31 })
32 })
33
34 ts := httptest.NewServer(r)
35 defer ts.Close()
36
37 if _, resp := testRequest(t, ts, "GET", "/", nil); resp != "root" {
38 t.Fatalf(resp)
39 }
40 if _, resp := testRequest(t, ts, "GET", "//", nil); resp != "root" {
41 t.Fatalf(resp)
42 }
43 if _, resp := testRequest(t, ts, "GET", "/accounts/admin", nil); resp != "admin" {
44 t.Fatalf(resp)
45 }
46 if _, resp := testRequest(t, ts, "GET", "/accounts/admin/", nil); resp != "admin" {
47 t.Fatalf(resp)
48 }
49 if _, resp := testRequest(t, ts, "GET", "/nothing-here", nil); resp != "nothing here" {
50 t.Fatalf(resp)
51 }
52 }
53
54 func TestStripSlashesInRoute(t *testing.T) {
55 r := chi.NewRouter()
56
57 r.NotFound(func(w http.ResponseWriter, r *http.Request) {
58 w.WriteHeader(404)
59 w.Write([]byte("nothing here"))
60 })
61
62 r.Get("/hi", func(w http.ResponseWriter, r *http.Request) {
63 w.Write([]byte("hi"))
64 })
65
66 r.Route("/accounts/{accountID}", func(r chi.Router) {
67 r.Use(StripSlashes)
68 r.Get("/", func(w http.ResponseWriter, r *http.Request) {
69 w.Write([]byte("accounts index"))
70 })
71 r.Get("/query", func(w http.ResponseWriter, r *http.Request) {
72 accountID := chi.URLParam(r, "accountID")
73 w.Write([]byte(accountID))
74 })
75 })
76
77 ts := httptest.NewServer(r)
78 defer ts.Close()
79
80 if _, resp := testRequest(t, ts, "GET", "/hi", nil); resp != "hi" {
81 t.Fatalf(resp)
82 }
83 if _, resp := testRequest(t, ts, "GET", "/hi/", nil); resp != "nothing here" {
84 t.Fatalf(resp)
85 }
86 if _, resp := testRequest(t, ts, "GET", "/accounts/admin", nil); resp != "accounts index" {
87 t.Fatalf(resp)
88 }
89 if _, resp := testRequest(t, ts, "GET", "/accounts/admin/", nil); resp != "accounts index" {
90 t.Fatalf(resp)
91 }
92 if _, resp := testRequest(t, ts, "GET", "/accounts/admin/query", nil); resp != "admin" {
93 t.Fatalf(resp)
94 }
95 if _, resp := testRequest(t, ts, "GET", "/accounts/admin/query/", nil); resp != "admin" {
96 t.Fatalf(resp)
97 }
98 }
99
100 func TestRedirectSlashes(t *testing.T) {
101 r := chi.NewRouter()
102
103
104
105 r.Use(RedirectSlashes)
106
107 r.NotFound(func(w http.ResponseWriter, r *http.Request) {
108 w.WriteHeader(404)
109 w.Write([]byte("nothing here"))
110 })
111
112 r.Get("/", func(w http.ResponseWriter, r *http.Request) {
113 w.Write([]byte("root"))
114 })
115
116 r.Route("/accounts/{accountID}", func(r chi.Router) {
117 r.Get("/", func(w http.ResponseWriter, r *http.Request) {
118 accountID := chi.URLParam(r, "accountID")
119 w.Write([]byte(accountID))
120 })
121 })
122
123 ts := httptest.NewServer(r)
124 defer ts.Close()
125
126 if resp, body := testRequest(t, ts, "GET", "/", nil); body != "root" && resp.StatusCode != 200 {
127 t.Fatalf(body)
128 }
129
130
131 if resp, body := testRequest(t, ts, "GET", "//", nil); body != "root" && resp.StatusCode != 200 {
132 t.Fatalf(body)
133 }
134
135 if resp, body := testRequest(t, ts, "GET", "/accounts/admin", nil); body != "admin" && resp.StatusCode != 200 {
136 t.Fatalf(body)
137 }
138
139
140 if resp, body := testRequest(t, ts, "GET", "/accounts/admin/", nil); body != "admin" && resp.StatusCode != 200 {
141 t.Fatalf(body)
142 }
143
144 if resp, body := testRequest(t, ts, "GET", "/nothing-here", nil); body != "nothing here" && resp.StatusCode != 200 {
145 t.Fatalf(body)
146 }
147
148
149 {
150 resp, body := testRequestNoRedirect(t, ts, "GET", "/accounts/someuser/", nil)
151 if resp.StatusCode != 301 {
152 t.Fatalf(body)
153 }
154 if resp.Header.Get("Location") != "/accounts/someuser" {
155 t.Fatalf("invalid redirection, should be /accounts/someuser")
156 }
157 }
158
159
160 {
161 resp, body := testRequestNoRedirect(t, ts, "GET", "/accounts/someuser/?a=1&b=2", nil)
162 if resp.StatusCode != 301 {
163 t.Fatalf(body)
164 }
165 if resp.Header.Get("Location") != "/accounts/someuser?a=1&b=2" {
166 t.Fatalf("invalid redirection, should be /accounts/someuser?a=1&b=2")
167 }
168
169 }
170 }
171
View as plain text