1 package middleware
2
3 import (
4 "compress/flate"
5 "compress/gzip"
6 "fmt"
7 "io"
8 "io/ioutil"
9 "net/http"
10 "net/http/httptest"
11 "strings"
12 "testing"
13
14 "github.com/go-chi/chi"
15 )
16
17 func TestCompressor(t *testing.T) {
18 r := chi.NewRouter()
19
20 compressor := NewCompressor(5, "text/html", "text/css")
21 if len(compressor.encoders) != 0 || len(compressor.pooledEncoders) != 2 {
22 t.Errorf("gzip and deflate should be pooled")
23 }
24
25 compressor.SetEncoder("nop", func(w io.Writer, _ int) io.Writer {
26 return w
27 })
28
29 if len(compressor.encoders) != 1 {
30 t.Errorf("nop encoder should be stored in the encoders map")
31 }
32
33 r.Use(compressor.Handler)
34
35 r.Get("/gethtml", func(w http.ResponseWriter, r *http.Request) {
36 w.Header().Set("Content-Type", "text/html")
37 w.Write([]byte("textstring"))
38 })
39
40 r.Get("/getcss", func(w http.ResponseWriter, r *http.Request) {
41 w.Header().Set("Content-Type", "text/html")
42 w.Write([]byte("textstring"))
43 })
44
45 r.Get("/getplain", func(w http.ResponseWriter, r *http.Request) {
46 w.Header().Set("Content-Type", "text/html")
47 w.Write([]byte("textstring"))
48 })
49
50 ts := httptest.NewServer(r)
51 defer ts.Close()
52
53 tests := []struct {
54 name string
55 path string
56 acceptedEncodings []string
57 expectedEncoding string
58 }{
59 {
60 name: "no expected encodings due to no accepted encodings",
61 path: "/gethtml",
62 acceptedEncodings: nil,
63 expectedEncoding: "",
64 },
65 {
66 name: "no expected encodings due to content type",
67 path: "/getplain",
68 acceptedEncodings: nil,
69 expectedEncoding: "",
70 },
71 {
72 name: "gzip is only encoding",
73 path: "/gethtml",
74 acceptedEncodings: []string{"gzip"},
75 expectedEncoding: "gzip",
76 },
77 {
78 name: "gzip is preferred over deflate",
79 path: "/getcss",
80 acceptedEncodings: []string{"gzip", "deflate"},
81 expectedEncoding: "gzip",
82 },
83 {
84 name: "deflate is used",
85 path: "/getcss",
86 acceptedEncodings: []string{"deflate"},
87 expectedEncoding: "deflate",
88 },
89 {
90
91 name: "nop is preferred",
92 path: "/getcss",
93 acceptedEncodings: []string{"nop, gzip, deflate"},
94 expectedEncoding: "nop",
95 },
96 }
97
98 for _, tc := range tests {
99 tc := tc
100 t.Run(tc.name, func(t *testing.T) {
101 resp, respString := testRequestWithAcceptedEncodings(t, ts, "GET", tc.path, tc.acceptedEncodings...)
102 if respString != "textstring" {
103 t.Errorf("response text doesn't match; expected:%q, got:%q", "textstring", respString)
104 }
105 if got := resp.Header.Get("Content-Encoding"); got != tc.expectedEncoding {
106 t.Errorf("expected encoding %q but got %q", tc.expectedEncoding, got)
107 }
108
109 })
110
111 }
112 }
113
114 func TestCompressorWildcards(t *testing.T) {
115 tests := []struct {
116 name string
117 types []string
118 typesCount int
119 wcCount int
120 recover string
121 }{
122 {
123 name: "defaults",
124 typesCount: 10,
125 },
126 {
127 name: "no wildcard",
128 types: []string{"text/plain", "text/html"},
129 typesCount: 2,
130 },
131 {
132 name: "invalid wildcard #1",
133 types: []string{"audio/*wav"},
134 recover: "middleware/compress: Unsupported content-type wildcard pattern 'audio/*wav'. Only '/*' supported",
135 },
136 {
137 name: "invalid wildcard #2",
138 types: []string{"application*/*"},
139 recover: "middleware/compress: Unsupported content-type wildcard pattern 'application*/*'. Only '/*' supported",
140 },
141 {
142 name: "valid wildcard",
143 types: []string{"text/*"},
144 wcCount: 1,
145 },
146 {
147 name: "mixed",
148 types: []string{"audio/wav", "text/*"},
149 typesCount: 1,
150 wcCount: 1,
151 },
152 }
153 for _, tt := range tests {
154 t.Run(tt.name, func(t *testing.T) {
155 defer func() {
156 if tt.recover == "" {
157 tt.recover = "<nil>"
158 }
159 if r := recover(); tt.recover != fmt.Sprintf("%v", r) {
160 t.Errorf("Unexpected value recovered: %v", r)
161 }
162 }()
163 compressor := NewCompressor(5, tt.types...)
164 if len(compressor.allowedTypes) != tt.typesCount {
165 t.Errorf("expected %d allowedTypes, got %d", tt.typesCount, len(compressor.allowedTypes))
166 }
167 if len(compressor.allowedWildcards) != tt.wcCount {
168 t.Errorf("expected %d allowedWildcards, got %d", tt.wcCount, len(compressor.allowedWildcards))
169 }
170 })
171 }
172 }
173
174 func testRequestWithAcceptedEncodings(t *testing.T, ts *httptest.Server, method, path string, encodings ...string) (*http.Response, string) {
175 req, err := http.NewRequest(method, ts.URL+path, nil)
176 if err != nil {
177 t.Fatal(err)
178 return nil, ""
179 }
180 if len(encodings) > 0 {
181 encodingsString := strings.Join(encodings, ",")
182 req.Header.Set("Accept-Encoding", encodingsString)
183 }
184
185 resp, err := http.DefaultClient.Do(req)
186 if err != nil {
187 t.Fatal(err)
188 return nil, ""
189 }
190
191 respBody := decodeResponseBody(t, resp)
192 defer resp.Body.Close()
193
194 return resp, respBody
195 }
196
// decodeResponseBody reads resp.Body to completion and returns it as a string.
// A "gzip" or "deflate" Content-Encoding is undone first. Any other value
// (none at all, or a custom encoding such as a pass-through) is read as-is.
// Failing to set up the decoder or to read the body fails the test
// immediately.
func decodeResponseBody(t *testing.T, resp *http.Response) string {
	t.Helper()

	var reader io.ReadCloser
	switch resp.Header.Get("Content-Encoding") {
	case "gzip":
		gz, err := gzip.NewReader(resp.Body)
		if err != nil {
			t.Fatal(err)
		}
		reader = gz
	case "deflate":
		reader = flate.NewReader(resp.Body)
	default:
		reader = resp.Body
	}
	// Deferred so the decoder is released even when ReadAll fails and
	// t.Fatal ends the test goroutine.
	defer reader.Close()

	body, err := ioutil.ReadAll(reader)
	if err != nil {
		t.Fatal(err)
	}
	return string(body)
}
220
View as plain text