...

Source file src/github.com/klauspost/compress/gzhttp/transport.go

Documentation: github.com/klauspost/compress/gzhttp

     1  // Copyright (c) 2021 Klaus Post. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package gzhttp
     6  
     7  import (
     8  	"io"
     9  	"net/http"
    10  	"strings"
    11  	"sync"
    12  
    13  	"github.com/klauspost/compress/gzip"
    14  	"github.com/klauspost/compress/zstd"
    15  )
    16  
    17  // Transport will wrap an HTTP transport with a custom handler
    18  // that will request gzip and automatically decompress it.
    19  // Using this is significantly faster than using the default transport.
    20  func Transport(parent http.RoundTripper, opts ...transportOption) http.RoundTripper {
    21  	g := gzRoundtripper{parent: parent, withZstd: true, withGzip: true}
    22  	for _, o := range opts {
    23  		o(&g)
    24  	}
    25  	var ae []string
    26  	if g.withZstd {
    27  		ae = append(ae, "zstd")
    28  	}
    29  	if g.withGzip {
    30  		ae = append(ae, "gzip")
    31  	}
    32  	g.acceptEncoding = strings.Join(ae, ",")
    33  	return &g
    34  }
    35  
    36  type transportOption func(c *gzRoundtripper)
    37  
    38  // TransportEnableZstd will send Zstandard as a compression option to the server.
    39  // Enabled by default, but may be disabled if future problems arise.
    40  func TransportEnableZstd(b bool) transportOption {
    41  	return func(c *gzRoundtripper) {
    42  		c.withZstd = b
    43  	}
    44  }
    45  
    46  // TransportEnableGzip will send Gzip as a compression option to the server.
    47  // Enabled by default.
    48  func TransportEnableGzip(b bool) transportOption {
    49  	return func(c *gzRoundtripper) {
    50  		c.withGzip = b
    51  	}
    52  }
    53  
    54  // TransportCustomEval will send the header of a response to a custom function.
    55  // If the function returns false, the response will be returned as-is,
    56  // Otherwise it will be decompressed based on Content-Encoding field, regardless
    57  // of whether the transport added the encoding.
    58  func TransportCustomEval(fn func(header http.Header) bool) transportOption {
    59  	return func(c *gzRoundtripper) {
    60  		c.customEval = fn
    61  	}
    62  }
    63  
    64  type gzRoundtripper struct {
    65  	parent             http.RoundTripper
    66  	acceptEncoding     string
    67  	withZstd, withGzip bool
    68  	customEval         func(header http.Header) bool
    69  }
    70  
    71  func (g *gzRoundtripper) RoundTrip(req *http.Request) (*http.Response, error) {
    72  	var requestedComp bool
    73  	if req.Header.Get("Accept-Encoding") == "" &&
    74  		req.Header.Get("Range") == "" &&
    75  		req.Method != "HEAD" {
    76  		// Request gzip only, not deflate. Deflate is ambiguous and
    77  		// not as universally supported anyway.
    78  		// See: https://zlib.net/zlib_faq.html#faq39
    79  		//
    80  		// Note that we don't request this for HEAD requests,
    81  		// due to a bug in nginx:
    82  		//   https://trac.nginx.org/nginx/ticket/358
    83  		//   https://golang.org/issue/5522
    84  		//
    85  		// We don't request gzip if the request is for a range, since
    86  		// auto-decoding a portion of a gzipped document will just fail
    87  		// anyway. See https://golang.org/issue/8923
    88  		requestedComp = len(g.acceptEncoding) > 0
    89  		req.Header.Set("Accept-Encoding", g.acceptEncoding)
    90  	}
    91  
    92  	resp, err := g.parent.RoundTrip(req)
    93  	if err != nil || !requestedComp {
    94  		return resp, err
    95  	}
    96  	decompress := false
    97  	if g.customEval != nil {
    98  		if !g.customEval(resp.Header) {
    99  			return resp, nil
   100  		}
   101  		decompress = true
   102  	}
   103  	// Decompress
   104  	if (decompress || g.withGzip) && asciiEqualFold(resp.Header.Get("Content-Encoding"), "gzip") {
   105  		resp.Body = &gzipReader{body: resp.Body}
   106  		resp.Header.Del("Content-Encoding")
   107  		resp.Header.Del("Content-Length")
   108  		resp.ContentLength = -1
   109  		resp.Uncompressed = true
   110  	}
   111  	if (decompress || g.withZstd) && asciiEqualFold(resp.Header.Get("Content-Encoding"), "zstd") {
   112  		resp.Body = &zstdReader{body: resp.Body}
   113  		resp.Header.Del("Content-Encoding")
   114  		resp.Header.Del("Content-Length")
   115  		resp.ContentLength = -1
   116  		resp.Uncompressed = true
   117  	}
   118  
   119  	return resp, nil
   120  }
   121  
   122  var gzReaderPool sync.Pool
   123  
   124  // gzipReader wraps a response body so it can lazily
   125  // call gzip.NewReader on the first call to Read
   126  type gzipReader struct {
   127  	body io.ReadCloser // underlying HTTP/1 response body framing
   128  	zr   *gzip.Reader  // lazily-initialized gzip reader
   129  	zerr error         // any error from gzip.NewReader; sticky
   130  }
   131  
   132  func (gz *gzipReader) Read(p []byte) (n int, err error) {
   133  	if gz.zr == nil {
   134  		if gz.zerr == nil {
   135  			zr, ok := gzReaderPool.Get().(*gzip.Reader)
   136  			if ok {
   137  				gz.zr, gz.zerr = zr, zr.Reset(gz.body)
   138  			} else {
   139  				gz.zr, gz.zerr = gzip.NewReader(gz.body)
   140  			}
   141  		}
   142  		if gz.zerr != nil {
   143  			return 0, gz.zerr
   144  		}
   145  	}
   146  
   147  	return gz.zr.Read(p)
   148  }
   149  
   150  func (gz *gzipReader) Close() error {
   151  	if gz.zr != nil {
   152  		gzReaderPool.Put(gz.zr)
   153  		gz.zr = nil
   154  	}
   155  	return gz.body.Close()
   156  }
   157  
   158  // asciiEqualFold is strings.EqualFold, ASCII only. It reports whether s and t
   159  // are equal, ASCII-case-insensitively.
   160  func asciiEqualFold(s, t string) bool {
   161  	if len(s) != len(t) {
   162  		return false
   163  	}
   164  	for i := 0; i < len(s); i++ {
   165  		if lower(s[i]) != lower(t[i]) {
   166  			return false
   167  		}
   168  	}
   169  	return true
   170  }
   171  
   172  // lower returns the ASCII lowercase version of b.
   173  func lower(b byte) byte {
   174  	if 'A' <= b && b <= 'Z' {
   175  		return b + ('a' - 'A')
   176  	}
   177  	return b
   178  }
   179  
   180  // zstdReaderPool pools zstd decoders.
   181  var zstdReaderPool sync.Pool
   182  
   183  // zstdReader wraps a response body so it can lazily
   184  // call gzip.NewReader on the first call to Read
   185  type zstdReader struct {
   186  	body io.ReadCloser // underlying HTTP/1 response body framing
   187  	zr   *zstd.Decoder // lazily-initialized gzip reader
   188  	zerr error         // any error from zstd.NewReader; sticky
   189  }
   190  
   191  func (zr *zstdReader) Read(p []byte) (n int, err error) {
   192  	if zr.zerr != nil {
   193  		return 0, zr.zerr
   194  	}
   195  	if zr.zr == nil {
   196  		if zr.zerr == nil {
   197  			reader, ok := zstdReaderPool.Get().(*zstd.Decoder)
   198  			if ok {
   199  				zr.zerr = reader.Reset(zr.body)
   200  				zr.zr = reader
   201  			} else {
   202  				zr.zr, zr.zerr = zstd.NewReader(zr.body, zstd.WithDecoderLowmem(true), zstd.WithDecoderMaxWindow(32<<20), zstd.WithDecoderConcurrency(1))
   203  			}
   204  		}
   205  		if zr.zerr != nil {
   206  			return 0, zr.zerr
   207  		}
   208  	}
   209  	n, err = zr.zr.Read(p)
   210  	if err != nil {
   211  		// Usually this will be io.EOF,
   212  		// stash the decoder and keep the error.
   213  		zr.zr.Reset(nil)
   214  		zstdReaderPool.Put(zr.zr)
   215  		zr.zr = nil
   216  		zr.zerr = err
   217  	}
   218  	return
   219  }
   220  
   221  func (zr *zstdReader) Close() error {
   222  	if zr.zr != nil {
   223  		zr.zr.Reset(nil)
   224  		zstdReaderPool.Put(zr.zr)
   225  		zr.zr = nil
   226  	}
   227  	return zr.body.Close()
   228  }
   229  

View as plain text