1 package middleware
2
3 import (
4 "bufio"
5 "compress/flate"
6 "compress/gzip"
7 "errors"
8 "fmt"
9 "io"
10 "io/ioutil"
11 "net"
12 "net/http"
13 "strings"
14 "sync"
15 )
16
// defaultCompressibleContentTypes lists the exact MIME types whose responses
// are compressed when NewCompressor (or Compress) is called without an
// explicit list of content types.
var defaultCompressibleContentTypes = []string{
	"text/html",
	"text/css",
	"text/plain",
	"text/javascript",
	"application/javascript",
	"application/x-javascript",
	"application/json",
	"application/atom+xml",
	"application/rss+xml",
	"image/svg+xml",
}
29
30
31
32
33
34
35
36
37
38
39
40
// Compress returns a middleware that compresses response bodies at the given
// level. Only responses whose Content-Type matches one of types (exact MIME
// types or "major/*" wildcards) are compressed; when types is empty a default
// set of text-like types is used. The encoding is negotiated from the
// request's Accept-Encoding header.
func Compress(level int, types ...string) func(next http.Handler) http.Handler {
	return NewCompressor(level, types...).Handler
}
45
46
// Compressor holds the configuration Handler uses to compress responses: the
// compression level, the registered encoders and their precedence, and the
// content types eligible for compression. Create one with NewCompressor.
type Compressor struct {
	// level is passed to every EncoderFunc when an encoder is created.
	level int

	// encoders maps an encoding name to its EncoderFunc for encoders that
	// cannot be reset; a fresh writer is created for each response.
	encoders map[string]EncoderFunc

	// pooledEncoders maps an encoding name to a pool of resettable writers.
	// SetEncoder places each encoding in exactly one of encoders or
	// pooledEncoders.
	pooledEncoders map[string]*sync.Pool

	// allowedTypes holds exact Content-Type values eligible for compression;
	// allowedWildcards holds major types (the part before "/*") whose every
	// subtype is eligible.
	allowedTypes     map[string]struct{}
	allowedWildcards map[string]struct{}

	// encodingPrecedence lists registered encoding names, most preferred
	// first; SetEncoder moves the encoding it registers to the front.
	encodingPrecedence []string
}
59
60
61
62
63
// NewCompressor builds a Compressor that compresses at the given level.
//
// Each entry of types is either an exact MIME type ("text/html") or a
// major-type wildcard ending in "/*" ("text/*"). Any other use of '*' is a
// programming error and panics. When types is empty,
// defaultCompressibleContentTypes is used instead.
//
// The returned Compressor has "deflate" and then "gzip" registered. Because
// SetEncoder gives the most recently registered encoding the highest
// precedence, gzip is preferred when a client accepts both.
func NewCompressor(level int, types ...string) *Compressor {
	exact := make(map[string]struct{})
	wildcards := make(map[string]struct{})

	if len(types) == 0 {
		for _, ct := range defaultCompressibleContentTypes {
			exact[ct] = struct{}{}
		}
	}

	for _, ct := range types {
		major := strings.TrimSuffix(ct, "/*")
		if strings.Contains(major, "*") {
			panic(fmt.Sprintf("middleware/compress: Unsupported content-type wildcard pattern '%s'. Only '/*' supported", ct))
		}
		// TrimSuffix only changes the string when the "/*" suffix was present.
		if major != ct {
			wildcards[major] = struct{}{}
			continue
		}
		exact[ct] = struct{}{}
	}

	c := &Compressor{
		level:            level,
		encoders:         make(map[string]EncoderFunc),
		pooledEncoders:   make(map[string]*sync.Pool),
		allowedTypes:     exact,
		allowedWildcards: wildcards,
	}

	// Registration order matters: the last encoder set wins precedence.
	c.SetEncoder("deflate", encoderDeflate)
	c.SetEncoder("gzip", encoderGzip)

	return c
}
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
// SetEncoder registers fn as the encoder for the content-coding named
// encoding (compared case-insensitively). It replaces any encoder previously
// registered under that name and moves the encoding to the front of the
// precedence list, so the most recently set encoder is preferred when a
// client accepts several.
//
// If the writer fn produces implements Reset(io.Writer), instances are
// recycled through a sync.Pool. Otherwise fn is called once per response.
// It panics if encoding is empty or fn is nil.
func (c *Compressor) SetEncoder(encoding string, fn EncoderFunc) {
	name := strings.ToLower(encoding)
	switch {
	case name == "":
		panic("the encoding can not be empty")
	case fn == nil:
		panic("attempted to set a nil encoder function")
	}

	// Forget any previous registration; delete is a no-op for missing keys.
	delete(c.pooledEncoders, name)
	delete(c.encoders, name)

	// Build one throwaway writer to learn whether this encoder can be reset
	// onto a new destination, which is what makes pooling possible.
	probe := fn(ioutil.Discard, c.level)
	if _, resettable := probe.(ioResetterWriter); resettable {
		c.pooledEncoders[name] = &sync.Pool{
			New: func() interface{} {
				return fn(ioutil.Discard, c.level)
			},
		}
	} else {
		c.encoders[name] = fn
	}

	// Rebuild the precedence list: this encoding first, then every other
	// encoding in its existing relative order.
	order := make([]string, 0, len(c.encodingPrecedence)+1)
	order = append(order, name)
	for _, existing := range c.encodingPrecedence {
		if existing != name {
			order = append(order, existing)
		}
	}
	c.encodingPrecedence = order
}
191
192
193
// Handler wraps next so that eligible responses are compressed with the
// highest-precedence encoding the request's Accept-Encoding header allows.
// Whether a given response is actually compressed is decided when its
// headers are written.
func (c *Compressor) Handler(next http.Handler) http.Handler {
	serve := func(w http.ResponseWriter, r *http.Request) {
		encoder, encoding, release := c.selectEncoder(r.Header, w)

		// Without a negotiated encoder, the "compressed" destination is simply
		// the original ResponseWriter.
		target := io.Writer(w)
		if encoder != nil {
			target = encoder
		}

		cw := &compressResponseWriter{
			ResponseWriter:   w,
			w:                target,
			contentTypes:     c.allowedTypes,
			contentWildcards: c.allowedWildcards,
			encoding:         encoding,
			compressable:     false,
		}

		// Deferred calls run in reverse order: cw.Close runs first, then
		// release (which returns a pooled encoder to its pool).
		defer release()
		defer cw.Close()

		next.ServeHTTP(cw, r)
	}
	return http.HandlerFunc(serve)
}
216
217
// selectEncoder picks the first encoding in c.encodingPrecedence that the
// Accept-Encoding header in h permits and returns three values: a writer that
// encodes into w, the encoding's name, and a cleanup func to call once the
// response is finished (for pooled encoders it returns the writer to its
// pool).
//
// If no registered encoding is acceptable, or every acceptable encoder fails
// to construct, it returns (nil, "", no-op cleanup) and the response is sent
// unencoded.
func (c *Compressor) selectEncoder(h http.Header, w io.Writer) (io.Writer, string, func()) {
	header := h.Get("Accept-Encoding")

	// Content-coding names are case-insensitive; split into the header's
	// comma-separated entries.
	accepted := strings.Split(strings.ToLower(header), ",")

	for _, name := range c.encodingPrecedence {
		if !matchAcceptEncoding(accepted, name) {
			continue
		}

		if pool, ok := c.pooledEncoders[name]; ok {
			encoder := pool.Get().(ioResetterWriter)
			encoder.Reset(w)
			return encoder, name, func() { pool.Put(encoder) }
		}

		if fn, ok := c.encoders[name]; ok {
			// An EncoderFunc reports failure (e.g. an invalid level) by
			// returning nil. Returning the name with a nil writer would make
			// the response advertise this Content-Encoding while sending
			// plain bytes, so fall through to the next acceptable encoding.
			if encoder := fn(w, c.level); encoder != nil {
				return encoder, name, func() {}
			}
		}
	}

	return nil, "", func() {}
}
246
// matchAcceptEncoding reports whether any entry of accepted (the lower-cased
// Accept-Encoding header split on ",") permits the given content-coding.
//
// Entries are compared by their coding token, ignoring surrounding whitespace
// and parameters, so a coding that merely contains encoding as a substring
// (e.g. "deflate-raw") does not match. The legacy "x-" alias (e.g. "x-gzip")
// is treated as equivalent. An entry with a zero quality value
// ("gzip;q=0") means the client refuses that coding, so it does not match.
func matchAcceptEncoding(accepted []string, encoding string) bool {
	for _, entry := range accepted {
		parts := strings.Split(entry, ";")
		coding := strings.TrimSpace(parts[0])
		if coding != encoding && coding != "x-"+encoding {
			continue
		}
		if !hasZeroQuality(parts[1:]) {
			return true
		}
	}
	return false
}

// hasZeroQuality reports whether params (the ";"-separated parameters of one
// Accept-Encoding entry) carry a q-value of zero, e.g. "q=0" or "q=0.000".
func hasZeroQuality(params []string) bool {
	for _, p := range params {
		p = strings.TrimSpace(p)
		if !strings.HasPrefix(p, "q=") {
			continue
		}
		v := strings.TrimSpace(strings.TrimPrefix(p, "q="))
		// A qvalue is zero exactly when it is "0" optionally followed by "."
		// and only zero digits (RFC 7231 qvalue grammar).
		return strings.HasPrefix(v, "0") && strings.Trim(v, "0.") == ""
	}
	return false
}
255
256
257
258
259
// EncoderFunc creates a writer that compresses data written to it into w,
// using the given compression level. Returning nil signals that the encoder
// could not be created (the built-in gzip and deflate encoders do so when
// the level is invalid).
type EncoderFunc func(w io.Writer, level int) io.Writer
261
262
// ioResetterWriter is implemented by encoders that can be redirected to a new
// destination with Reset. SetEncoder keeps such encoders in a sync.Pool and
// reuses them across responses instead of building a new one each time.
type ioResetterWriter interface {
	io.Writer
	Reset(w io.Writer)
}
267
// compressResponseWriter wraps an http.ResponseWriter. Once WriteHeader has
// decided the response is eligible for compression, body writes are routed
// through w (the negotiated encoder) instead of the wrapped ResponseWriter.
type compressResponseWriter struct {
	http.ResponseWriter

	// w is the body destination used while compressable is true: the encoder
	// chosen for this request, or the original ResponseWriter if none was.
	w                io.Writer
	encoding         string              // negotiated encoding name; "" if none
	contentTypes     map[string]struct{} // exact Content-Type values eligible for compression
	contentWildcards map[string]struct{} // major types (e.g. "text" for "text/*") eligible for compression
	wroteHeader      bool                // set by the first WriteHeader call
	compressable     bool                // whether body writes go through w
}
280
// isCompressable reports whether the response's Content-Type, with any
// ";"-separated parameters stripped, is one of the configured exact types or
// belongs to one of the configured major-type wildcards.
func (cw *compressResponseWriter) isCompressable() bool {
	mediaType := cw.Header().Get("Content-Type")
	if semi := strings.Index(mediaType, ";"); semi >= 0 {
		mediaType = mediaType[:semi]
	}

	if _, exact := cw.contentTypes[mediaType]; exact {
		return true
	}

	// Fall back to the "major/*" wildcards; a missing or leading "/" means
	// there is no major type to look up.
	slash := strings.Index(mediaType, "/")
	if slash <= 0 {
		return false
	}
	_, wildcard := cw.contentWildcards[mediaType[:slash]]
	return wildcard
}
299
300 func (cw *compressResponseWriter) WriteHeader(code int) {
301 if cw.wroteHeader {
302 cw.ResponseWriter.WriteHeader(code)
303 return
304 }
305 cw.wroteHeader = true
306 defer cw.ResponseWriter.WriteHeader(code)
307
308
309 if cw.Header().Get("Content-Encoding") != "" {
310 return
311 }
312
313 if !cw.isCompressable() {
314 cw.compressable = false
315 return
316 }
317
318 if cw.encoding != "" {
319 cw.compressable = true
320 cw.Header().Set("Content-Encoding", cw.encoding)
321 cw.Header().Set("Vary", "Accept-Encoding")
322
323
324 cw.Header().Del("Content-Length")
325 }
326 }
327
// Write sends p through the encoder when compression is enabled for this
// response, or directly to the underlying ResponseWriter otherwise. If no
// status has been written yet it first calls WriteHeader(http.StatusOK),
// because that is where the compress/no-compress decision is made.
func (cw *compressResponseWriter) Write(p []byte) (int, error) {
	if !cw.wroteHeader {
		cw.WriteHeader(http.StatusOK)
	}
	out := cw.writer()
	return out.Write(p)
}
335
// writer returns the current destination for body bytes: w when compression
// is enabled for this response, otherwise the underlying ResponseWriter.
func (cw *compressResponseWriter) writer() io.Writer {
	if !cw.compressable {
		return cw.ResponseWriter
	}
	return cw.w
}
343
// compressFlusher matches encoders whose Flush returns an error, which
// distinguishes them from http.Flusher (whose Flush returns nothing).
type compressFlusher interface {
	Flush() error
}
347
// Flush pushes buffered response data toward the client.
//
// If the current destination is an http.Flusher, it is flushed. If it is an
// encoder with an error-returning Flush, the encoder is flushed first and
// then the underlying ResponseWriter is flushed, so the encoded bytes
// actually leave the server.
func (cw *compressResponseWriter) Flush() {
	out := cw.writer()

	if hf, ok := out.(http.Flusher); ok {
		hf.Flush()
	}

	cf, ok := out.(compressFlusher)
	if !ok {
		return
	}
	// The encoder's error is dropped: http.Flusher's Flush has no way to
	// report it.
	cf.Flush()
	if hf, ok := cw.ResponseWriter.(http.Flusher); ok {
		hf.Flush()
	}
}
363
// Hijack hands the connection to the caller when the current body
// destination implements http.Hijacker; otherwise it returns an error. Once
// compression is enabled the destination is the encoder, so hijacking then
// succeeds only if the encoder itself implements http.Hijacker.
func (cw *compressResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	hj, ok := cw.writer().(http.Hijacker)
	if !ok {
		return nil, nil, errors.New("chi/middleware: http.Hijacker is unavailable on the writer")
	}
	return hj.Hijack()
}
370
// Push initiates an HTTP/2 server push for target when the underlying
// ResponseWriter supports http.Pusher, and returns an error otherwise.
//
// The underlying ResponseWriter is consulted directly rather than
// cw.writer(). Server push belongs to the connection, not the body stream.
// Once compression is enabled, writer() returns the encoder, and the
// built-in gzip/deflate encoders do not implement http.Pusher, so Push would
// fail even though the connection supports it.
func (cw *compressResponseWriter) Push(target string, opts *http.PushOptions) error {
	if ps, ok := cw.ResponseWriter.(http.Pusher); ok {
		return ps.Push(target, opts)
	}
	return errors.New("chi/middleware: http.Pusher is unavailable on the writer")
}
377
// Close closes the current body destination when it implements
// io.WriteCloser (the encoder, once compression is enabled), letting the
// encoder finish its stream. If the destination cannot be closed, it returns
// an error instead.
func (cw *compressResponseWriter) Close() error {
	wc, ok := cw.writer().(io.WriteCloser)
	if !ok {
		return errors.New("chi/middleware: io.WriteCloser is unavailable on the writer")
	}
	return wc.Close()
}
384
// encoderGzip is the EncoderFunc for the "gzip" content-coding. It returns a
// *gzip.Writer that writes into w at the given level, or nil when
// gzip.NewWriterLevel rejects the level.
func encoderGzip(w io.Writer, level int) io.Writer {
	if gw, err := gzip.NewWriterLevel(w, level); err == nil {
		return gw
	}
	return nil
}
392
// encoderDeflate is the EncoderFunc for the "deflate" content-coding. It
// returns a *flate.Writer that writes into w at the given level, or nil when
// flate.NewWriter rejects the level.
func encoderDeflate(w io.Writer, level int) io.Writer {
	if dw, err := flate.NewWriter(w, level); err == nil {
		return dw
	}
	return nil
}
400
View as plain text