...

Source file src/github.com/sassoftware/relic/lib/compresshttp/compress.go

Documentation: github.com/sassoftware/relic/lib/compresshttp

     1  //
     2  // Copyright (c) SAS Institute Inc.
     3  //
     4  // Licensed under the Apache License, Version 2.0 (the "License");
     5  // you may not use this file except in compliance with the License.
     6  // You may obtain a copy of the License at
     7  //
     8  //     http://www.apache.org/licenses/LICENSE-2.0
     9  //
    10  // Unless required by applicable law or agreed to in writing, software
    11  // distributed under the License is distributed on an "AS IS" BASIS,
    12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  // See the License for the specific language governing permissions and
    14  // limitations under the License.
    15  //
    16  
    17  package compresshttp
    18  
    19  import (
    20  	"compress/gzip"
    21  	"errors"
    22  	"io"
    23  	"io/ioutil"
    24  	"net/http"
    25  	"strings"
    26  	"sync/atomic"
    27  
    28  	"github.com/golang/snappy"
    29  )
    30  
    31  const (
    32  	contentEncoding  = "Content-Encoding"
    33  	contentLength    = "Content-Length"
    34  	EncodingIdentity = "identity"
    35  	EncodingGzip     = "gzip"
    36  	EncodingSnappy   = "x-snappy-framed"
    37  
    38  	AcceptedEncodings = EncodingSnappy + ", " + EncodingGzip
    39  )
    40  
    41  // higher is better
    42  var prefs = map[string]int{
    43  	EncodingGzip:   1,
    44  	EncodingSnappy: 2,
    45  }
    46  
    47  var ErrUnacceptableEncoding = errors.New("unknown Content-Encoding")
    48  
    49  func selectEncoding(acceptEncoding string) string {
    50  	var pref int
    51  	var best string
    52  	for _, encoding := range strings.Split(acceptEncoding, ",") {
    53  		encoding = strings.TrimSpace(strings.Split(encoding, ";")[0])
    54  		if p2 := prefs[encoding]; p2 > pref {
    55  			pref = p2
    56  			best = encoding
    57  		}
    58  	}
    59  	return best
    60  }
    61  
    62  func compress(encoding string, r io.Reader, w io.Writer) (err error) {
    63  	var compr io.WriteCloser
    64  	switch encoding {
    65  	case EncodingIdentity, "":
    66  		_, err = io.Copy(w, r)
    67  		return err
    68  	case EncodingGzip:
    69  		compr, err = gzip.NewWriterLevel(w, gzip.BestSpeed)
    70  	case EncodingSnappy:
    71  		compr = snappy.NewBufferedWriter(w)
    72  	default:
    73  		return ErrUnacceptableEncoding
    74  	}
    75  	if err == nil {
    76  		_, err = io.Copy(compr, r)
    77  	}
    78  	if err == nil {
    79  		err = compr.Close()
    80  	}
    81  	return
    82  }
    83  
    84  func decompress(encoding string, r io.Reader) (io.Reader, error) {
    85  	switch encoding {
    86  	case EncodingIdentity, "":
    87  		return ioutil.NopCloser(r), nil
    88  	case EncodingGzip:
    89  		return gzip.NewReader(r)
    90  	case EncodingSnappy:
    91  		return snappy.NewReader(r), nil
    92  	default:
    93  		return nil, ErrUnacceptableEncoding
    94  	}
    95  }
    96  
    97  func CompressRequest(request *http.Request, acceptEncoding string) error {
    98  	encoding := selectEncoding(acceptEncoding)
    99  	if encoding == "" {
   100  		return nil
   101  	}
   102  	plain := &readBlocker{Reader: request.Body}
   103  	pr, pw := io.Pipe()
   104  	go func() {
   105  		err := compress(encoding, plain, pw)
   106  		plain.Close()
   107  		pw.CloseWithError(err)
   108  	}()
   109  	// Ensure reads inside the goroutine fail after the request terminates.
   110  	// Otherwise there could be reads happening in parallel from multiple,
   111  	// different requests, if those requests are reading from the same
   112  	// underlying file. That could cause file pointers to move unexpectedly,
   113  	// and it's easier to prevent here than to make sure every use case is
   114  	// thread-safe.
   115  	request.Body = alsoClose{ReadCloser: pr, also: plain}
   116  	request.ContentLength = -1
   117  	request.Header.Set(contentEncoding, encoding)
   118  	return nil
   119  }
   120  
   121  // Wrap a reader and block all reads once Close() is called
   122  type readBlocker struct {
   123  	io.Reader
   124  	closed uint32
   125  }
   126  
   127  func (r *readBlocker) Read(d []byte) (int, error) {
   128  	if atomic.LoadUint32(&r.closed) != 0 {
   129  		return 0, errors.New("stream is closed")
   130  	}
   131  	return r.Reader.Read(d)
   132  }
   133  
   134  func (r *readBlocker) Close() error {
   135  	if c, ok := r.Reader.(io.Closer); ok {
   136  		if err := c.Close(); err != nil {
   137  			return err
   138  		}
   139  	}
   140  	atomic.StoreUint32(&r.closed, 1)
   141  	return nil
   142  }
   143  
   144  type alsoClose struct {
   145  	io.ReadCloser
   146  	also io.Closer
   147  }
   148  
   149  func (a alsoClose) Close() error {
   150  	a.also.Close()
   151  	return a.ReadCloser.Close()
   152  }
   153  
   154  func DecompressRequest(request *http.Request) error {
   155  	r, err := decompress(request.Header.Get(contentEncoding), request.Body)
   156  	if err == nil {
   157  		request.Body = ioutil.NopCloser(r)
   158  		request.ContentLength = -1
   159  	}
   160  	return err
   161  }
   162  
   163  func CompressResponse(r io.Reader, acceptEncoding string, writer http.ResponseWriter, status int) error {
   164  	encoding := selectEncoding(acceptEncoding)
   165  	if encoding != "" {
   166  		writer.Header().Set(contentEncoding, encoding)
   167  		writer.Header().Del(contentLength)
   168  	} else {
   169  		writer.Header().Del(contentEncoding)
   170  	}
   171  	writer.WriteHeader(status)
   172  	return compress(encoding, r, writer)
   173  }
   174  
   175  func DecompressResponse(response *http.Response) error {
   176  	r, err := decompress(response.Header.Get(contentEncoding), response.Body)
   177  	if err == nil {
   178  		response.Body = readAndClose{r: r, c: response.Body}
   179  		response.ContentLength = -1
   180  	}
   181  	return err
   182  }
   183  
   184  type readAndClose struct {
   185  	r io.Reader
   186  	c io.Closer
   187  }
   188  
   189  func (rc readAndClose) Read(d []byte) (int, error) {
   190  	return rc.r.Read(d)
   191  }
   192  
   193  func (rc readAndClose) Close() error {
   194  	return rc.c.Close()
   195  }
   196  

View as plain text