...

Source file src/github.com/gregjones/httpcache/httpcache.go

Documentation: github.com/gregjones/httpcache

     1  // Package httpcache provides a http.RoundTripper implementation that works as a
     2  // mostly RFC-compliant cache for http responses.
     3  //
     4  // It is only suitable for use as a 'private' cache (i.e. for a web-browser or an API-client
     5  // and not for a shared proxy).
     6  //
     7  package httpcache
     8  
     9  import (
    10  	"bufio"
    11  	"bytes"
    12  	"errors"
    13  	"io"
    14  	"io/ioutil"
    15  	"net/http"
    16  	"net/http/httputil"
    17  	"strings"
    18  	"sync"
    19  	"time"
    20  )
    21  
    22  const (
    23  	stale = iota
    24  	fresh
    25  	transparent
    26  	// XFromCache is the header added to responses that are returned from the cache
    27  	XFromCache = "X-From-Cache"
    28  )
    29  
    30  // A Cache interface is used by the Transport to store and retrieve responses.
    31  type Cache interface {
    32  	// Get returns the []byte representation of a cached response and a bool
    33  	// set to true if the value isn't empty
    34  	Get(key string) (responseBytes []byte, ok bool)
    35  	// Set stores the []byte representation of a response against a key
    36  	Set(key string, responseBytes []byte)
    37  	// Delete removes the value associated with the key
    38  	Delete(key string)
    39  }
    40  
    41  // cacheKey returns the cache key for req.
    42  func cacheKey(req *http.Request) string {
    43  	if req.Method == http.MethodGet {
    44  		return req.URL.String()
    45  	} else {
    46  		return req.Method + " " + req.URL.String()
    47  	}
    48  }
    49  
    50  // CachedResponse returns the cached http.Response for req if present, and nil
    51  // otherwise.
    52  func CachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) {
    53  	cachedVal, ok := c.Get(cacheKey(req))
    54  	if !ok {
    55  		return
    56  	}
    57  
    58  	b := bytes.NewBuffer(cachedVal)
    59  	return http.ReadResponse(bufio.NewReader(b), req)
    60  }
    61  
    62  // MemoryCache is an implemtation of Cache that stores responses in an in-memory map.
    63  type MemoryCache struct {
    64  	mu    sync.RWMutex
    65  	items map[string][]byte
    66  }
    67  
    68  // Get returns the []byte representation of the response and true if present, false if not
    69  func (c *MemoryCache) Get(key string) (resp []byte, ok bool) {
    70  	c.mu.RLock()
    71  	resp, ok = c.items[key]
    72  	c.mu.RUnlock()
    73  	return resp, ok
    74  }
    75  
    76  // Set saves response resp to the cache with key
    77  func (c *MemoryCache) Set(key string, resp []byte) {
    78  	c.mu.Lock()
    79  	c.items[key] = resp
    80  	c.mu.Unlock()
    81  }
    82  
    83  // Delete removes key from the cache
    84  func (c *MemoryCache) Delete(key string) {
    85  	c.mu.Lock()
    86  	delete(c.items, key)
    87  	c.mu.Unlock()
    88  }
    89  
    90  // NewMemoryCache returns a new Cache that will store items in an in-memory map
    91  func NewMemoryCache() *MemoryCache {
    92  	c := &MemoryCache{items: map[string][]byte{}}
    93  	return c
    94  }
    95  
    96  // Transport is an implementation of http.RoundTripper that will return values from a cache
    97  // where possible (avoiding a network request) and will additionally add validators (etag/if-modified-since)
    98  // to repeated requests allowing servers to return 304 / Not Modified
    99  type Transport struct {
   100  	// The RoundTripper interface actually used to make requests
   101  	// If nil, http.DefaultTransport is used
   102  	Transport http.RoundTripper
   103  	Cache     Cache
   104  	// If true, responses returned from the cache will be given an extra header, X-From-Cache
   105  	MarkCachedResponses bool
   106  }
   107  
   108  // NewTransport returns a new Transport with the
   109  // provided Cache implementation and MarkCachedResponses set to true
   110  func NewTransport(c Cache) *Transport {
   111  	return &Transport{Cache: c, MarkCachedResponses: true}
   112  }
   113  
   114  // Client returns an *http.Client that caches responses.
   115  func (t *Transport) Client() *http.Client {
   116  	return &http.Client{Transport: t}
   117  }
   118  
   119  // varyMatches will return false unless all of the cached values for the headers listed in Vary
   120  // match the new request
   121  func varyMatches(cachedResp *http.Response, req *http.Request) bool {
   122  	for _, header := range headerAllCommaSepValues(cachedResp.Header, "vary") {
   123  		header = http.CanonicalHeaderKey(header)
   124  		if header != "" && req.Header.Get(header) != cachedResp.Header.Get("X-Varied-"+header) {
   125  			return false
   126  		}
   127  	}
   128  	return true
   129  }
   130  
   131  // RoundTrip takes a Request and returns a Response
   132  //
   133  // If there is a fresh Response already in cache, then it will be returned without connecting to
   134  // the server.
   135  //
   136  // If there is a stale Response, then any validators it contains will be set on the new request
   137  // to give the server a chance to respond with NotModified. If this happens, then the cached Response
   138  // will be returned.
   139  func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
   140  	cacheKey := cacheKey(req)
   141  	cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == ""
   142  	var cachedResp *http.Response
   143  	if cacheable {
   144  		cachedResp, err = CachedResponse(t.Cache, req)
   145  	} else {
   146  		// Need to invalidate an existing value
   147  		t.Cache.Delete(cacheKey)
   148  	}
   149  
   150  	transport := t.Transport
   151  	if transport == nil {
   152  		transport = http.DefaultTransport
   153  	}
   154  
   155  	if cacheable && cachedResp != nil && err == nil {
   156  		if t.MarkCachedResponses {
   157  			cachedResp.Header.Set(XFromCache, "1")
   158  		}
   159  
   160  		if varyMatches(cachedResp, req) {
   161  			// Can only use cached value if the new request doesn't Vary significantly
   162  			freshness := getFreshness(cachedResp.Header, req.Header)
   163  			if freshness == fresh {
   164  				return cachedResp, nil
   165  			}
   166  
   167  			if freshness == stale {
   168  				var req2 *http.Request
   169  				// Add validators if caller hasn't already done so
   170  				etag := cachedResp.Header.Get("etag")
   171  				if etag != "" && req.Header.Get("etag") == "" {
   172  					req2 = cloneRequest(req)
   173  					req2.Header.Set("if-none-match", etag)
   174  				}
   175  				lastModified := cachedResp.Header.Get("last-modified")
   176  				if lastModified != "" && req.Header.Get("last-modified") == "" {
   177  					if req2 == nil {
   178  						req2 = cloneRequest(req)
   179  					}
   180  					req2.Header.Set("if-modified-since", lastModified)
   181  				}
   182  				if req2 != nil {
   183  					req = req2
   184  				}
   185  			}
   186  		}
   187  
   188  		resp, err = transport.RoundTrip(req)
   189  		if err == nil && req.Method == "GET" && resp.StatusCode == http.StatusNotModified {
   190  			// Replace the 304 response with the one from cache, but update with some new headers
   191  			endToEndHeaders := getEndToEndHeaders(resp.Header)
   192  			for _, header := range endToEndHeaders {
   193  				cachedResp.Header[header] = resp.Header[header]
   194  			}
   195  			resp = cachedResp
   196  		} else if (err != nil || (cachedResp != nil && resp.StatusCode >= 500)) &&
   197  			req.Method == "GET" && canStaleOnError(cachedResp.Header, req.Header) {
   198  			// In case of transport failure and stale-if-error activated, returns cached content
   199  			// when available
   200  			return cachedResp, nil
   201  		} else {
   202  			if err != nil || resp.StatusCode != http.StatusOK {
   203  				t.Cache.Delete(cacheKey)
   204  			}
   205  			if err != nil {
   206  				return nil, err
   207  			}
   208  		}
   209  	} else {
   210  		reqCacheControl := parseCacheControl(req.Header)
   211  		if _, ok := reqCacheControl["only-if-cached"]; ok {
   212  			resp = newGatewayTimeoutResponse(req)
   213  		} else {
   214  			resp, err = transport.RoundTrip(req)
   215  			if err != nil {
   216  				return nil, err
   217  			}
   218  		}
   219  	}
   220  
   221  	if cacheable && canStore(parseCacheControl(req.Header), parseCacheControl(resp.Header)) {
   222  		for _, varyKey := range headerAllCommaSepValues(resp.Header, "vary") {
   223  			varyKey = http.CanonicalHeaderKey(varyKey)
   224  			fakeHeader := "X-Varied-" + varyKey
   225  			reqValue := req.Header.Get(varyKey)
   226  			if reqValue != "" {
   227  				resp.Header.Set(fakeHeader, reqValue)
   228  			}
   229  		}
   230  		switch req.Method {
   231  		case "GET":
   232  			// Delay caching until EOF is reached.
   233  			resp.Body = &cachingReadCloser{
   234  				R: resp.Body,
   235  				OnEOF: func(r io.Reader) {
   236  					resp := *resp
   237  					resp.Body = ioutil.NopCloser(r)
   238  					respBytes, err := httputil.DumpResponse(&resp, true)
   239  					if err == nil {
   240  						t.Cache.Set(cacheKey, respBytes)
   241  					}
   242  				},
   243  			}
   244  		default:
   245  			respBytes, err := httputil.DumpResponse(resp, true)
   246  			if err == nil {
   247  				t.Cache.Set(cacheKey, respBytes)
   248  			}
   249  		}
   250  	} else {
   251  		t.Cache.Delete(cacheKey)
   252  	}
   253  	return resp, nil
   254  }
   255  
   256  // ErrNoDateHeader indicates that the HTTP headers contained no Date header.
   257  var ErrNoDateHeader = errors.New("no Date header")
   258  
   259  // Date parses and returns the value of the Date header.
   260  func Date(respHeaders http.Header) (date time.Time, err error) {
   261  	dateHeader := respHeaders.Get("date")
   262  	if dateHeader == "" {
   263  		err = ErrNoDateHeader
   264  		return
   265  	}
   266  
   267  	return time.Parse(time.RFC1123, dateHeader)
   268  }
   269  
   270  type realClock struct{}
   271  
   272  func (c *realClock) since(d time.Time) time.Duration {
   273  	return time.Since(d)
   274  }
   275  
   276  type timer interface {
   277  	since(d time.Time) time.Duration
   278  }
   279  
   280  var clock timer = &realClock{}
   281  
   282  // getFreshness will return one of fresh/stale/transparent based on the cache-control
   283  // values of the request and the response
   284  //
   285  // fresh indicates the response can be returned
   286  // stale indicates that the response needs validating before it is returned
   287  // transparent indicates the response should not be used to fulfil the request
   288  //
   289  // Because this is only a private cache, 'public' and 'private' in cache-control aren't
   290  // signficant. Similarly, smax-age isn't used.
   291  func getFreshness(respHeaders, reqHeaders http.Header) (freshness int) {
   292  	respCacheControl := parseCacheControl(respHeaders)
   293  	reqCacheControl := parseCacheControl(reqHeaders)
   294  	if _, ok := reqCacheControl["no-cache"]; ok {
   295  		return transparent
   296  	}
   297  	if _, ok := respCacheControl["no-cache"]; ok {
   298  		return stale
   299  	}
   300  	if _, ok := reqCacheControl["only-if-cached"]; ok {
   301  		return fresh
   302  	}
   303  
   304  	date, err := Date(respHeaders)
   305  	if err != nil {
   306  		return stale
   307  	}
   308  	currentAge := clock.since(date)
   309  
   310  	var lifetime time.Duration
   311  	var zeroDuration time.Duration
   312  
   313  	// If a response includes both an Expires header and a max-age directive,
   314  	// the max-age directive overrides the Expires header, even if the Expires header is more restrictive.
   315  	if maxAge, ok := respCacheControl["max-age"]; ok {
   316  		lifetime, err = time.ParseDuration(maxAge + "s")
   317  		if err != nil {
   318  			lifetime = zeroDuration
   319  		}
   320  	} else {
   321  		expiresHeader := respHeaders.Get("Expires")
   322  		if expiresHeader != "" {
   323  			expires, err := time.Parse(time.RFC1123, expiresHeader)
   324  			if err != nil {
   325  				lifetime = zeroDuration
   326  			} else {
   327  				lifetime = expires.Sub(date)
   328  			}
   329  		}
   330  	}
   331  
   332  	if maxAge, ok := reqCacheControl["max-age"]; ok {
   333  		// the client is willing to accept a response whose age is no greater than the specified time in seconds
   334  		lifetime, err = time.ParseDuration(maxAge + "s")
   335  		if err != nil {
   336  			lifetime = zeroDuration
   337  		}
   338  	}
   339  	if minfresh, ok := reqCacheControl["min-fresh"]; ok {
   340  		//  the client wants a response that will still be fresh for at least the specified number of seconds.
   341  		minfreshDuration, err := time.ParseDuration(minfresh + "s")
   342  		if err == nil {
   343  			currentAge = time.Duration(currentAge + minfreshDuration)
   344  		}
   345  	}
   346  
   347  	if maxstale, ok := reqCacheControl["max-stale"]; ok {
   348  		// Indicates that the client is willing to accept a response that has exceeded its expiration time.
   349  		// If max-stale is assigned a value, then the client is willing to accept a response that has exceeded
   350  		// its expiration time by no more than the specified number of seconds.
   351  		// If no value is assigned to max-stale, then the client is willing to accept a stale response of any age.
   352  		//
   353  		// Responses served only because of a max-stale value are supposed to have a Warning header added to them,
   354  		// but that seems like a  hassle, and is it actually useful? If so, then there needs to be a different
   355  		// return-value available here.
   356  		if maxstale == "" {
   357  			return fresh
   358  		}
   359  		maxstaleDuration, err := time.ParseDuration(maxstale + "s")
   360  		if err == nil {
   361  			currentAge = time.Duration(currentAge - maxstaleDuration)
   362  		}
   363  	}
   364  
   365  	if lifetime > currentAge {
   366  		return fresh
   367  	}
   368  
   369  	return stale
   370  }
   371  
   372  // Returns true if either the request or the response includes the stale-if-error
   373  // cache control extension: https://tools.ietf.org/html/rfc5861
   374  func canStaleOnError(respHeaders, reqHeaders http.Header) bool {
   375  	respCacheControl := parseCacheControl(respHeaders)
   376  	reqCacheControl := parseCacheControl(reqHeaders)
   377  
   378  	var err error
   379  	lifetime := time.Duration(-1)
   380  
   381  	if staleMaxAge, ok := respCacheControl["stale-if-error"]; ok {
   382  		if staleMaxAge != "" {
   383  			lifetime, err = time.ParseDuration(staleMaxAge + "s")
   384  			if err != nil {
   385  				return false
   386  			}
   387  		} else {
   388  			return true
   389  		}
   390  	}
   391  	if staleMaxAge, ok := reqCacheControl["stale-if-error"]; ok {
   392  		if staleMaxAge != "" {
   393  			lifetime, err = time.ParseDuration(staleMaxAge + "s")
   394  			if err != nil {
   395  				return false
   396  			}
   397  		} else {
   398  			return true
   399  		}
   400  	}
   401  
   402  	if lifetime >= 0 {
   403  		date, err := Date(respHeaders)
   404  		if err != nil {
   405  			return false
   406  		}
   407  		currentAge := clock.since(date)
   408  		if lifetime > currentAge {
   409  			return true
   410  		}
   411  	}
   412  
   413  	return false
   414  }
   415  
   416  func getEndToEndHeaders(respHeaders http.Header) []string {
   417  	// These headers are always hop-by-hop
   418  	hopByHopHeaders := map[string]struct{}{
   419  		"Connection":          {},
   420  		"Keep-Alive":          {},
   421  		"Proxy-Authenticate":  {},
   422  		"Proxy-Authorization": {},
   423  		"Te":                  {},
   424  		"Trailers":            {},
   425  		"Transfer-Encoding":   {},
   426  		"Upgrade":             {},
   427  	}
   428  
   429  	for _, extra := range strings.Split(respHeaders.Get("connection"), ",") {
   430  		// any header listed in connection, if present, is also considered hop-by-hop
   431  		if strings.Trim(extra, " ") != "" {
   432  			hopByHopHeaders[http.CanonicalHeaderKey(extra)] = struct{}{}
   433  		}
   434  	}
   435  	endToEndHeaders := []string{}
   436  	for respHeader := range respHeaders {
   437  		if _, ok := hopByHopHeaders[respHeader]; !ok {
   438  			endToEndHeaders = append(endToEndHeaders, respHeader)
   439  		}
   440  	}
   441  	return endToEndHeaders
   442  }
   443  
   444  func canStore(reqCacheControl, respCacheControl cacheControl) (canStore bool) {
   445  	if _, ok := respCacheControl["no-store"]; ok {
   446  		return false
   447  	}
   448  	if _, ok := reqCacheControl["no-store"]; ok {
   449  		return false
   450  	}
   451  	return true
   452  }
   453  
   454  func newGatewayTimeoutResponse(req *http.Request) *http.Response {
   455  	var braw bytes.Buffer
   456  	braw.WriteString("HTTP/1.1 504 Gateway Timeout\r\n\r\n")
   457  	resp, err := http.ReadResponse(bufio.NewReader(&braw), req)
   458  	if err != nil {
   459  		panic(err)
   460  	}
   461  	return resp
   462  }
   463  
   464  // cloneRequest returns a clone of the provided *http.Request.
   465  // The clone is a shallow copy of the struct and its Header map.
   466  // (This function copyright goauth2 authors: https://code.google.com/p/goauth2)
   467  func cloneRequest(r *http.Request) *http.Request {
   468  	// shallow copy of the struct
   469  	r2 := new(http.Request)
   470  	*r2 = *r
   471  	// deep copy of the Header
   472  	r2.Header = make(http.Header)
   473  	for k, s := range r.Header {
   474  		r2.Header[k] = s
   475  	}
   476  	return r2
   477  }
   478  
   479  type cacheControl map[string]string
   480  
   481  func parseCacheControl(headers http.Header) cacheControl {
   482  	cc := cacheControl{}
   483  	ccHeader := headers.Get("Cache-Control")
   484  	for _, part := range strings.Split(ccHeader, ",") {
   485  		part = strings.Trim(part, " ")
   486  		if part == "" {
   487  			continue
   488  		}
   489  		if strings.ContainsRune(part, '=') {
   490  			keyval := strings.Split(part, "=")
   491  			cc[strings.Trim(keyval[0], " ")] = strings.Trim(keyval[1], ",")
   492  		} else {
   493  			cc[part] = ""
   494  		}
   495  	}
   496  	return cc
   497  }
   498  
   499  // headerAllCommaSepValues returns all comma-separated values (each
   500  // with whitespace trimmed) for header name in headers. According to
   501  // Section 4.2 of the HTTP/1.1 spec
   502  // (http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2),
   503  // values from multiple occurrences of a header should be concatenated, if
   504  // the header's value is a comma-separated list.
   505  func headerAllCommaSepValues(headers http.Header, name string) []string {
   506  	var vals []string
   507  	for _, val := range headers[http.CanonicalHeaderKey(name)] {
   508  		fields := strings.Split(val, ",")
   509  		for i, f := range fields {
   510  			fields[i] = strings.TrimSpace(f)
   511  		}
   512  		vals = append(vals, fields...)
   513  	}
   514  	return vals
   515  }
   516  
   517  // cachingReadCloser is a wrapper around ReadCloser R that calls OnEOF
   518  // handler with a full copy of the content read from R when EOF is
   519  // reached.
   520  type cachingReadCloser struct {
   521  	// Underlying ReadCloser.
   522  	R io.ReadCloser
   523  	// OnEOF is called with a copy of the content of R when EOF is reached.
   524  	OnEOF func(io.Reader)
   525  
   526  	buf bytes.Buffer // buf stores a copy of the content of R.
   527  }
   528  
   529  // Read reads the next len(p) bytes from R or until R is drained. The
   530  // return value n is the number of bytes read. If R has no data to
   531  // return, err is io.EOF and OnEOF is called with a full copy of what
   532  // has been read so far.
   533  func (r *cachingReadCloser) Read(p []byte) (n int, err error) {
   534  	n, err = r.R.Read(p)
   535  	r.buf.Write(p[:n])
   536  	if err == io.EOF {
   537  		r.OnEOF(bytes.NewReader(r.buf.Bytes()))
   538  	}
   539  	return n, err
   540  }
   541  
   542  func (r *cachingReadCloser) Close() error {
   543  	return r.R.Close()
   544  }
   545  
   546  // NewMemoryCacheTransport returns a new Transport using the in-memory cache implementation
   547  func NewMemoryCacheTransport() *Transport {
   548  	c := NewMemoryCache()
   549  	t := NewTransport(c)
   550  	return t
   551  }
   552  

View as plain text