...

Source file src/github.com/go-chi/chi/middleware/compress.go

Documentation: github.com/go-chi/chi/middleware

     1  package middleware
     2  
     3  import (
     4  	"bufio"
     5  	"compress/flate"
     6  	"compress/gzip"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"io/ioutil"
    11  	"net"
    12  	"net/http"
    13  	"strings"
    14  	"sync"
    15  )
    16  
    17  var defaultCompressibleContentTypes = []string{
    18  	"text/html",
    19  	"text/css",
    20  	"text/plain",
    21  	"text/javascript",
    22  	"application/javascript",
    23  	"application/x-javascript",
    24  	"application/json",
    25  	"application/atom+xml",
    26  	"application/rss+xml",
    27  	"image/svg+xml",
    28  }
    29  
    30  // Compress is a middleware that compresses response
    31  // body of a given content types to a data format based
    32  // on Accept-Encoding request header. It uses a given
    33  // compression level.
    34  //
    35  // NOTE: make sure to set the Content-Type header on your response
    36  // otherwise this middleware will not compress the response body. For ex, in
    37  // your handler you should set w.Header().Set("Content-Type", http.DetectContentType(yourBody))
    38  // or set it manually.
    39  //
    40  // Passing a compression level of 5 is sensible value
    41  func Compress(level int, types ...string) func(next http.Handler) http.Handler {
    42  	compressor := NewCompressor(level, types...)
    43  	return compressor.Handler
    44  }
    45  
    46  // Compressor represents a set of encoding configurations.
    47  type Compressor struct {
    48  	level int // The compression level.
    49  	// The mapping of encoder names to encoder functions.
    50  	encoders map[string]EncoderFunc
    51  	// The mapping of pooled encoders to pools.
    52  	pooledEncoders map[string]*sync.Pool
    53  	// The set of content types allowed to be compressed.
    54  	allowedTypes     map[string]struct{}
    55  	allowedWildcards map[string]struct{}
    56  	// The list of encoders in order of decreasing precedence.
    57  	encodingPrecedence []string
    58  }
    59  
    60  // NewCompressor creates a new Compressor that will handle encoding responses.
    61  //
    62  // The level should be one of the ones defined in the flate package.
    63  // The types are the content types that are allowed to be compressed.
    64  func NewCompressor(level int, types ...string) *Compressor {
    65  	// If types are provided, set those as the allowed types. If none are
    66  	// provided, use the default list.
    67  	allowedTypes := make(map[string]struct{})
    68  	allowedWildcards := make(map[string]struct{})
    69  	if len(types) > 0 {
    70  		for _, t := range types {
    71  			if strings.Contains(strings.TrimSuffix(t, "/*"), "*") {
    72  				panic(fmt.Sprintf("middleware/compress: Unsupported content-type wildcard pattern '%s'. Only '/*' supported", t))
    73  			}
    74  			if strings.HasSuffix(t, "/*") {
    75  				allowedWildcards[strings.TrimSuffix(t, "/*")] = struct{}{}
    76  			} else {
    77  				allowedTypes[t] = struct{}{}
    78  			}
    79  		}
    80  	} else {
    81  		for _, t := range defaultCompressibleContentTypes {
    82  			allowedTypes[t] = struct{}{}
    83  		}
    84  	}
    85  
    86  	c := &Compressor{
    87  		level:            level,
    88  		encoders:         make(map[string]EncoderFunc),
    89  		pooledEncoders:   make(map[string]*sync.Pool),
    90  		allowedTypes:     allowedTypes,
    91  		allowedWildcards: allowedWildcards,
    92  	}
    93  
    94  	// Set the default encoders.  The precedence order uses the reverse
    95  	// ordering that the encoders were added. This means adding new encoders
    96  	// will move them to the front of the order.
    97  	//
    98  	// TODO:
    99  	// lzma: Opera.
   100  	// sdch: Chrome, Android. Gzip output + dictionary header.
   101  	// br:   Brotli, see https://github.com/go-chi/chi/pull/326
   102  
   103  	// HTTP 1.1 "deflate" (RFC 2616) stands for DEFLATE data (RFC 1951)
   104  	// wrapped with zlib (RFC 1950). The zlib wrapper uses Adler-32
   105  	// checksum compared to CRC-32 used in "gzip" and thus is faster.
   106  	//
   107  	// But.. some old browsers (MSIE, Safari 5.1) incorrectly expect
   108  	// raw DEFLATE data only, without the mentioned zlib wrapper.
   109  	// Because of this major confusion, most modern browsers try it
   110  	// both ways, first looking for zlib headers.
   111  	// Quote by Mark Adler: http://stackoverflow.com/a/9186091/385548
   112  	//
   113  	// The list of browsers having problems is quite big, see:
   114  	// http://zoompf.com/blog/2012/02/lose-the-wait-http-compression
   115  	// https://web.archive.org/web/20120321182910/http://www.vervestudios.co/projects/compression-tests/results
   116  	//
   117  	// That's why we prefer gzip over deflate. It's just more reliable
   118  	// and not significantly slower than gzip.
   119  	c.SetEncoder("deflate", encoderDeflate)
   120  
   121  	// TODO: Exception for old MSIE browsers that can't handle non-HTML?
   122  	// https://zoompf.com/blog/2012/02/lose-the-wait-http-compression
   123  	c.SetEncoder("gzip", encoderGzip)
   124  
   125  	// NOTE: Not implemented, intentionally:
   126  	// case "compress": // LZW. Deprecated.
   127  	// case "bzip2":    // Too slow on-the-fly.
   128  	// case "zopfli":   // Too slow on-the-fly.
   129  	// case "xz":       // Too slow on-the-fly.
   130  	return c
   131  }
   132  
   133  // SetEncoder can be used to set the implementation of a compression algorithm.
   134  //
   135  // The encoding should be a standardised identifier. See:
   136  // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Accept-Encoding
   137  //
   138  // For example, add the Brotli algortithm:
   139  //
   140  //  import brotli_enc "gopkg.in/kothar/brotli-go.v0/enc"
   141  //
   142  //  compressor := middleware.NewCompressor(5, "text/html")
   143  //  compressor.SetEncoder("br", func(w http.ResponseWriter, level int) io.Writer {
   144  //    params := brotli_enc.NewBrotliParams()
   145  //    params.SetQuality(level)
   146  //    return brotli_enc.NewBrotliWriter(params, w)
   147  //  })
   148  func (c *Compressor) SetEncoder(encoding string, fn EncoderFunc) {
   149  	encoding = strings.ToLower(encoding)
   150  	if encoding == "" {
   151  		panic("the encoding can not be empty")
   152  	}
   153  	if fn == nil {
   154  		panic("attempted to set a nil encoder function")
   155  	}
   156  
   157  	// If we are adding a new encoder that is already registered, we have to
   158  	// clear that one out first.
   159  	if _, ok := c.pooledEncoders[encoding]; ok {
   160  		delete(c.pooledEncoders, encoding)
   161  	}
   162  	if _, ok := c.encoders[encoding]; ok {
   163  		delete(c.encoders, encoding)
   164  	}
   165  
   166  	// If the encoder supports Resetting (IoReseterWriter), then it can be pooled.
   167  	encoder := fn(ioutil.Discard, c.level)
   168  	if encoder != nil {
   169  		if _, ok := encoder.(ioResetterWriter); ok {
   170  			pool := &sync.Pool{
   171  				New: func() interface{} {
   172  					return fn(ioutil.Discard, c.level)
   173  				},
   174  			}
   175  			c.pooledEncoders[encoding] = pool
   176  		}
   177  	}
   178  	// If the encoder is not in the pooledEncoders, add it to the normal encoders.
   179  	if _, ok := c.pooledEncoders[encoding]; !ok {
   180  		c.encoders[encoding] = fn
   181  	}
   182  
   183  	for i, v := range c.encodingPrecedence {
   184  		if v == encoding {
   185  			c.encodingPrecedence = append(c.encodingPrecedence[:i], c.encodingPrecedence[i+1:]...)
   186  		}
   187  	}
   188  
   189  	c.encodingPrecedence = append([]string{encoding}, c.encodingPrecedence...)
   190  }
   191  
   192  // Handler returns a new middleware that will compress the response based on the
   193  // current Compressor.
   194  func (c *Compressor) Handler(next http.Handler) http.Handler {
   195  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   196  		encoder, encoding, cleanup := c.selectEncoder(r.Header, w)
   197  
   198  		cw := &compressResponseWriter{
   199  			ResponseWriter:   w,
   200  			w:                w,
   201  			contentTypes:     c.allowedTypes,
   202  			contentWildcards: c.allowedWildcards,
   203  			encoding:         encoding,
   204  			compressable:     false, // determined in post-handler
   205  		}
   206  		if encoder != nil {
   207  			cw.w = encoder
   208  		}
   209  		// Re-add the encoder to the pool if applicable.
   210  		defer cleanup()
   211  		defer cw.Close()
   212  
   213  		next.ServeHTTP(cw, r)
   214  	})
   215  }
   216  
   217  // selectEncoder returns the encoder, the name of the encoder, and a closer function.
   218  func (c *Compressor) selectEncoder(h http.Header, w io.Writer) (io.Writer, string, func()) {
   219  	header := h.Get("Accept-Encoding")
   220  
   221  	// Parse the names of all accepted algorithms from the header.
   222  	accepted := strings.Split(strings.ToLower(header), ",")
   223  
   224  	// Find supported encoder by accepted list by precedence
   225  	for _, name := range c.encodingPrecedence {
   226  		if matchAcceptEncoding(accepted, name) {
   227  			if pool, ok := c.pooledEncoders[name]; ok {
   228  				encoder := pool.Get().(ioResetterWriter)
   229  				cleanup := func() {
   230  					pool.Put(encoder)
   231  				}
   232  				encoder.Reset(w)
   233  				return encoder, name, cleanup
   234  
   235  			}
   236  			if fn, ok := c.encoders[name]; ok {
   237  				return fn(w, c.level), name, func() {}
   238  			}
   239  		}
   240  
   241  	}
   242  
   243  	// No encoder found to match the accepted encoding
   244  	return nil, "", func() {}
   245  }
   246  
   247  func matchAcceptEncoding(accepted []string, encoding string) bool {
   248  	for _, v := range accepted {
   249  		if strings.Contains(v, encoding) {
   250  			return true
   251  		}
   252  	}
   253  	return false
   254  }
   255  
   256  // An EncoderFunc is a function that wraps the provided io.Writer with a
   257  // streaming compression algorithm and returns it.
   258  //
   259  // In case of failure, the function should return nil.
   260  type EncoderFunc func(w io.Writer, level int) io.Writer
   261  
   262  // Interface for types that allow resetting io.Writers.
   263  type ioResetterWriter interface {
   264  	io.Writer
   265  	Reset(w io.Writer)
   266  }
   267  
   268  type compressResponseWriter struct {
   269  	http.ResponseWriter
   270  
   271  	// The streaming encoder writer to be used if there is one. Otherwise,
   272  	// this is just the normal writer.
   273  	w                io.Writer
   274  	encoding         string
   275  	contentTypes     map[string]struct{}
   276  	contentWildcards map[string]struct{}
   277  	wroteHeader      bool
   278  	compressable     bool
   279  }
   280  
   281  func (cw *compressResponseWriter) isCompressable() bool {
   282  	// Parse the first part of the Content-Type response header.
   283  	contentType := cw.Header().Get("Content-Type")
   284  	if idx := strings.Index(contentType, ";"); idx >= 0 {
   285  		contentType = contentType[0:idx]
   286  	}
   287  
   288  	// Is the content type compressable?
   289  	if _, ok := cw.contentTypes[contentType]; ok {
   290  		return true
   291  	}
   292  	if idx := strings.Index(contentType, "/"); idx > 0 {
   293  		contentType = contentType[0:idx]
   294  		_, ok := cw.contentWildcards[contentType]
   295  		return ok
   296  	}
   297  	return false
   298  }
   299  
   300  func (cw *compressResponseWriter) WriteHeader(code int) {
   301  	if cw.wroteHeader {
   302  		cw.ResponseWriter.WriteHeader(code) // Allow multiple calls to propagate.
   303  		return
   304  	}
   305  	cw.wroteHeader = true
   306  	defer cw.ResponseWriter.WriteHeader(code)
   307  
   308  	// Already compressed data?
   309  	if cw.Header().Get("Content-Encoding") != "" {
   310  		return
   311  	}
   312  
   313  	if !cw.isCompressable() {
   314  		cw.compressable = false
   315  		return
   316  	}
   317  
   318  	if cw.encoding != "" {
   319  		cw.compressable = true
   320  		cw.Header().Set("Content-Encoding", cw.encoding)
   321  		cw.Header().Set("Vary", "Accept-Encoding")
   322  
   323  		// The content-length after compression is unknown
   324  		cw.Header().Del("Content-Length")
   325  	}
   326  }
   327  
   328  func (cw *compressResponseWriter) Write(p []byte) (int, error) {
   329  	if !cw.wroteHeader {
   330  		cw.WriteHeader(http.StatusOK)
   331  	}
   332  
   333  	return cw.writer().Write(p)
   334  }
   335  
   336  func (cw *compressResponseWriter) writer() io.Writer {
   337  	if cw.compressable {
   338  		return cw.w
   339  	} else {
   340  		return cw.ResponseWriter
   341  	}
   342  }
   343  
   344  type compressFlusher interface {
   345  	Flush() error
   346  }
   347  
   348  func (cw *compressResponseWriter) Flush() {
   349  	if f, ok := cw.writer().(http.Flusher); ok {
   350  		f.Flush()
   351  	}
   352  	// If the underlying writer has a compression flush signature,
   353  	// call this Flush() method instead
   354  	if f, ok := cw.writer().(compressFlusher); ok {
   355  		f.Flush()
   356  
   357  		// Also flush the underlying response writer
   358  		if f, ok := cw.ResponseWriter.(http.Flusher); ok {
   359  			f.Flush()
   360  		}
   361  	}
   362  }
   363  
   364  func (cw *compressResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
   365  	if hj, ok := cw.writer().(http.Hijacker); ok {
   366  		return hj.Hijack()
   367  	}
   368  	return nil, nil, errors.New("chi/middleware: http.Hijacker is unavailable on the writer")
   369  }
   370  
   371  func (cw *compressResponseWriter) Push(target string, opts *http.PushOptions) error {
   372  	if ps, ok := cw.writer().(http.Pusher); ok {
   373  		return ps.Push(target, opts)
   374  	}
   375  	return errors.New("chi/middleware: http.Pusher is unavailable on the writer")
   376  }
   377  
   378  func (cw *compressResponseWriter) Close() error {
   379  	if c, ok := cw.writer().(io.WriteCloser); ok {
   380  		return c.Close()
   381  	}
   382  	return errors.New("chi/middleware: io.WriteCloser is unavailable on the writer")
   383  }
   384  
   385  func encoderGzip(w io.Writer, level int) io.Writer {
   386  	gw, err := gzip.NewWriterLevel(w, level)
   387  	if err != nil {
   388  		return nil
   389  	}
   390  	return gw
   391  }
   392  
   393  func encoderDeflate(w io.Writer, level int) io.Writer {
   394  	dw, err := flate.NewWriter(w, level)
   395  	if err != nil {
   396  		return nil
   397  	}
   398  	return dw
   399  }
   400  

View as plain text