1
2
3
4
5
6
7 package httpcache
8
9 import (
10 "bufio"
11 "bytes"
12 "errors"
13 "io"
14 "io/ioutil"
15 "net/http"
16 "net/http/httputil"
17 "strings"
18 "sync"
19 "time"
20 )
21
// Freshness states returned by getFreshness.
const (
	// stale means the cached response must be revalidated before use.
	stale = iota
	// fresh means the cached response can be served as-is.
	fresh
	// transparent means the request asked to bypass the cache (no-cache).
	transparent
)

// XFromCache is the header RoundTrip sets to "1" on responses loaded from
// the cache when Transport.MarkCachedResponses is enabled.
const XFromCache = "X-From-Cache"
29
30
// Cache is the storage backend used by Transport. Values are opaque byte
// slices; in this file they are responses serialized by
// httputil.DumpResponse and parsed back by CachedResponse.
type Cache interface {
	// Get returns the bytes stored under key. ok reports whether an entry
	// was found.
	Get(key string) (responseBytes []byte, ok bool)
	// Set stores responseBytes under key.
	Set(key string, responseBytes []byte)
	// Delete removes the entry stored under key.
	Delete(key string)
}
40
41
// cacheKey returns the key under which a response to req is stored: the
// request URL for GET, and "<METHOD> <URL>" for every other method.
func cacheKey(req *http.Request) string {
	target := req.URL.String()
	if req.Method != http.MethodGet {
		return req.Method + " " + target
	}
	return target
}
49
50
51
52 func CachedResponse(c Cache, req *http.Request) (resp *http.Response, err error) {
53 cachedVal, ok := c.Get(cacheKey(req))
54 if !ok {
55 return
56 }
57
58 b := bytes.NewBuffer(cachedVal)
59 return http.ReadResponse(bufio.NewReader(b), req)
60 }
61
62
// MemoryCache is a Cache implementation that keeps entries in a map guarded
// by a read/write mutex. The zero value is ready to use.
type MemoryCache struct {
	mu    sync.RWMutex
	items map[string][]byte
}

// Get returns the bytes stored under key and whether an entry was found.
func (c *MemoryCache) Get(key string) (resp []byte, ok bool) {
	c.mu.RLock()
	defer c.mu.RUnlock()
	resp, ok = c.items[key]
	return resp, ok
}

// Set stores resp under key, replacing any previous entry.
func (c *MemoryCache) Set(key string, resp []byte) {
	c.mu.Lock()
	defer c.mu.Unlock()
	// A zero-value MemoryCache has a nil map; writing to it would panic.
	if c.items == nil {
		c.items = make(map[string][]byte)
	}
	c.items[key] = resp
}

// Delete removes the entry stored under key, if any.
func (c *MemoryCache) Delete(key string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	delete(c.items, key)
}

// NewMemoryCache returns an empty MemoryCache.
func NewMemoryCache() *MemoryCache {
	return &MemoryCache{items: make(map[string][]byte)}
}
95
96
97
98
// Transport is an http.RoundTripper that serves responses from Cache when
// possible and otherwise delegates to the wrapped Transport.
type Transport struct {
	// Transport is the RoundTripper used for requests that reach the
	// network. RoundTrip uses http.DefaultTransport when it is nil.
	Transport http.RoundTripper
	// Cache is where serialized responses are looked up, stored and deleted.
	Cache Cache
	// MarkCachedResponses, when true, makes RoundTrip set the X-From-Cache
	// header on responses that were loaded from Cache.
	MarkCachedResponses bool
}
107
108
109
110 func NewTransport(c Cache) *Transport {
111 return &Transport{Cache: c, MarkCachedResponses: true}
112 }
113
114
115 func (t *Transport) Client() *http.Client {
116 return &http.Client{Transport: t}
117 }
118
119
120
121 func varyMatches(cachedResp *http.Response, req *http.Request) bool {
122 for _, header := range headerAllCommaSepValues(cachedResp.Header, "vary") {
123 header = http.CanonicalHeaderKey(header)
124 if header != "" && req.Header.Get(header) != cachedResp.Header.Get("X-Varied-"+header) {
125 return false
126 }
127 }
128 return true
129 }
130
131
132
133
134
135
136
137
138
139 func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
140 cacheKey := cacheKey(req)
141 cacheable := (req.Method == "GET" || req.Method == "HEAD") && req.Header.Get("range") == ""
142 var cachedResp *http.Response
143 if cacheable {
144 cachedResp, err = CachedResponse(t.Cache, req)
145 } else {
146
147 t.Cache.Delete(cacheKey)
148 }
149
150 transport := t.Transport
151 if transport == nil {
152 transport = http.DefaultTransport
153 }
154
155 if cacheable && cachedResp != nil && err == nil {
156 if t.MarkCachedResponses {
157 cachedResp.Header.Set(XFromCache, "1")
158 }
159
160 if varyMatches(cachedResp, req) {
161
162 freshness := getFreshness(cachedResp.Header, req.Header)
163 if freshness == fresh {
164 return cachedResp, nil
165 }
166
167 if freshness == stale {
168 var req2 *http.Request
169
170 etag := cachedResp.Header.Get("etag")
171 if etag != "" && req.Header.Get("etag") == "" {
172 req2 = cloneRequest(req)
173 req2.Header.Set("if-none-match", etag)
174 }
175 lastModified := cachedResp.Header.Get("last-modified")
176 if lastModified != "" && req.Header.Get("last-modified") == "" {
177 if req2 == nil {
178 req2 = cloneRequest(req)
179 }
180 req2.Header.Set("if-modified-since", lastModified)
181 }
182 if req2 != nil {
183 req = req2
184 }
185 }
186 }
187
188 resp, err = transport.RoundTrip(req)
189 if err == nil && req.Method == "GET" && resp.StatusCode == http.StatusNotModified {
190
191 endToEndHeaders := getEndToEndHeaders(resp.Header)
192 for _, header := range endToEndHeaders {
193 cachedResp.Header[header] = resp.Header[header]
194 }
195 resp = cachedResp
196 } else if (err != nil || (cachedResp != nil && resp.StatusCode >= 500)) &&
197 req.Method == "GET" && canStaleOnError(cachedResp.Header, req.Header) {
198
199
200 return cachedResp, nil
201 } else {
202 if err != nil || resp.StatusCode != http.StatusOK {
203 t.Cache.Delete(cacheKey)
204 }
205 if err != nil {
206 return nil, err
207 }
208 }
209 } else {
210 reqCacheControl := parseCacheControl(req.Header)
211 if _, ok := reqCacheControl["only-if-cached"]; ok {
212 resp = newGatewayTimeoutResponse(req)
213 } else {
214 resp, err = transport.RoundTrip(req)
215 if err != nil {
216 return nil, err
217 }
218 }
219 }
220
221 if cacheable && canStore(parseCacheControl(req.Header), parseCacheControl(resp.Header)) {
222 for _, varyKey := range headerAllCommaSepValues(resp.Header, "vary") {
223 varyKey = http.CanonicalHeaderKey(varyKey)
224 fakeHeader := "X-Varied-" + varyKey
225 reqValue := req.Header.Get(varyKey)
226 if reqValue != "" {
227 resp.Header.Set(fakeHeader, reqValue)
228 }
229 }
230 switch req.Method {
231 case "GET":
232
233 resp.Body = &cachingReadCloser{
234 R: resp.Body,
235 OnEOF: func(r io.Reader) {
236 resp := *resp
237 resp.Body = ioutil.NopCloser(r)
238 respBytes, err := httputil.DumpResponse(&resp, true)
239 if err == nil {
240 t.Cache.Set(cacheKey, respBytes)
241 }
242 },
243 }
244 default:
245 respBytes, err := httputil.DumpResponse(resp, true)
246 if err == nil {
247 t.Cache.Set(cacheKey, respBytes)
248 }
249 }
250 } else {
251 t.Cache.Delete(cacheKey)
252 }
253 return resp, nil
254 }
255
256
// ErrNoDateHeader indicates that the HTTP headers contained no Date header.
var ErrNoDateHeader = errors.New("no Date header")

// Date parses and returns the value of the Date header in respHeaders.
//
// It returns ErrNoDateHeader when the header is absent or empty, and the
// parse error when the value is in none of the accepted formats.
func Date(respHeaders http.Header) (date time.Time, err error) {
	dateHeader := respHeaders.Get("date")
	if dateHeader == "" {
		return time.Time{}, ErrNoDateHeader
	}

	// http.ParseTime accepts all three HTTP/1.1 date formats (the preferred
	// fixed-length form, RFC 850 and ANSI C asctime).
	if date, err = http.ParseTime(dateHeader); err == nil {
		return date, nil
	}
	// Fall back to the looser RFC 1123 layout, which also accepts zone
	// abbreviations other than GMT.
	return time.Parse(time.RFC1123, dateHeader)
}
269
// timer abstracts "time elapsed since d" so that the package-level clock can
// be replaced.
type timer interface {
	since(d time.Time) time.Duration
}

// realClock is the timer backed by the system clock.
type realClock struct{}

// since returns the time elapsed since start.
func (*realClock) since(start time.Time) time.Duration {
	return time.Since(start)
}

// clock is the timer consulted by getFreshness and canStaleOnError.
var clock timer = new(realClock)
281
282
283
284
285
286
287
288
289
290
291 func getFreshness(respHeaders, reqHeaders http.Header) (freshness int) {
292 respCacheControl := parseCacheControl(respHeaders)
293 reqCacheControl := parseCacheControl(reqHeaders)
294 if _, ok := reqCacheControl["no-cache"]; ok {
295 return transparent
296 }
297 if _, ok := respCacheControl["no-cache"]; ok {
298 return stale
299 }
300 if _, ok := reqCacheControl["only-if-cached"]; ok {
301 return fresh
302 }
303
304 date, err := Date(respHeaders)
305 if err != nil {
306 return stale
307 }
308 currentAge := clock.since(date)
309
310 var lifetime time.Duration
311 var zeroDuration time.Duration
312
313
314
315 if maxAge, ok := respCacheControl["max-age"]; ok {
316 lifetime, err = time.ParseDuration(maxAge + "s")
317 if err != nil {
318 lifetime = zeroDuration
319 }
320 } else {
321 expiresHeader := respHeaders.Get("Expires")
322 if expiresHeader != "" {
323 expires, err := time.Parse(time.RFC1123, expiresHeader)
324 if err != nil {
325 lifetime = zeroDuration
326 } else {
327 lifetime = expires.Sub(date)
328 }
329 }
330 }
331
332 if maxAge, ok := reqCacheControl["max-age"]; ok {
333
334 lifetime, err = time.ParseDuration(maxAge + "s")
335 if err != nil {
336 lifetime = zeroDuration
337 }
338 }
339 if minfresh, ok := reqCacheControl["min-fresh"]; ok {
340
341 minfreshDuration, err := time.ParseDuration(minfresh + "s")
342 if err == nil {
343 currentAge = time.Duration(currentAge + minfreshDuration)
344 }
345 }
346
347 if maxstale, ok := reqCacheControl["max-stale"]; ok {
348
349
350
351
352
353
354
355
356 if maxstale == "" {
357 return fresh
358 }
359 maxstaleDuration, err := time.ParseDuration(maxstale + "s")
360 if err == nil {
361 currentAge = time.Duration(currentAge - maxstaleDuration)
362 }
363 }
364
365 if lifetime > currentAge {
366 return fresh
367 }
368
369 return stale
370 }
371
372
373
374 func canStaleOnError(respHeaders, reqHeaders http.Header) bool {
375 respCacheControl := parseCacheControl(respHeaders)
376 reqCacheControl := parseCacheControl(reqHeaders)
377
378 var err error
379 lifetime := time.Duration(-1)
380
381 if staleMaxAge, ok := respCacheControl["stale-if-error"]; ok {
382 if staleMaxAge != "" {
383 lifetime, err = time.ParseDuration(staleMaxAge + "s")
384 if err != nil {
385 return false
386 }
387 } else {
388 return true
389 }
390 }
391 if staleMaxAge, ok := reqCacheControl["stale-if-error"]; ok {
392 if staleMaxAge != "" {
393 lifetime, err = time.ParseDuration(staleMaxAge + "s")
394 if err != nil {
395 return false
396 }
397 } else {
398 return true
399 }
400 }
401
402 if lifetime >= 0 {
403 date, err := Date(respHeaders)
404 if err != nil {
405 return false
406 }
407 currentAge := clock.since(date)
408 if lifetime > currentAge {
409 return true
410 }
411 }
412
413 return false
414 }
415
// getEndToEndHeaders returns the names of the headers in respHeaders that
// are not hop-by-hop. Hop-by-hop headers are the fixed set below plus every
// header named in a Connection header of the response. The order of the
// returned names is unspecified. Keys of respHeaders are compared as-is, so
// they are expected to be in canonical form.
func getEndToEndHeaders(respHeaders http.Header) []string {
	hopByHopHeaders := map[string]struct{}{
		"Connection":          {},
		"Keep-Alive":          {},
		"Proxy-Authenticate":  {},
		"Proxy-Authorization": {},
		"Te":                  {},
		"Trailer":             {}, // the actual header name; RFC 2616 13.5.1 misspells it
		"Trailers":            {},
		"Transfer-Encoding":   {},
		"Upgrade":             {},
	}

	// Any header listed in Connection is also hop-by-hop. Names are trimmed
	// before canonicalisation: "close, X-Foo" must yield "X-Foo", not
	// " X-Foo", or the lookup below never matches.
	for _, line := range respHeaders["Connection"] {
		for _, extra := range strings.Split(line, ",") {
			if extra = strings.TrimSpace(extra); extra != "" {
				hopByHopHeaders[http.CanonicalHeaderKey(extra)] = struct{}{}
			}
		}
	}

	endToEndHeaders := make([]string, 0, len(respHeaders))
	for respHeader := range respHeaders {
		if _, ok := hopByHopHeaders[respHeader]; !ok {
			endToEndHeaders = append(endToEndHeaders, respHeader)
		}
	}
	return endToEndHeaders
}
443
444 func canStore(reqCacheControl, respCacheControl cacheControl) (canStore bool) {
445 if _, ok := respCacheControl["no-store"]; ok {
446 return false
447 }
448 if _, ok := reqCacheControl["no-store"]; ok {
449 return false
450 }
451 return true
452 }
453
// newGatewayTimeoutResponse builds an empty "504 Gateway Timeout" response
// associated with req by parsing a literal HTTP/1.1 status line.
func newGatewayTimeoutResponse(req *http.Request) *http.Response {
	raw := bytes.NewBufferString("HTTP/1.1 504 Gateway Timeout\r\n\r\n")
	resp, err := http.ReadResponse(bufio.NewReader(raw), req)
	if err != nil {
		// The input is a fixed, well-formed response head, so a parse
		// failure here would be a programming error.
		panic(err)
	}
	return resp
}
463
464
465
466
// cloneRequest returns a copy of r that can have its headers modified
// without affecting r: the struct is copied shallowly (URL, Body, etc. are
// shared), while the Header map and each of its value slices are copied.
func cloneRequest(r *http.Request) *http.Request {
	// Shallow copy of the struct.
	r2 := new(http.Request)
	*r2 = *r

	// Deep copy of the Header, including the value slices, so that neither
	// an in-place write nor an append on the clone can reach r's headers.
	r2.Header = make(http.Header, len(r.Header))
	for k, s := range r.Header {
		r2.Header[k] = append([]string(nil), s...)
	}
	return r2
}
478
// cacheControl maps lower-cased Cache-Control directive names to their
// values; a directive without a value maps to the empty string.
type cacheControl map[string]string

// parseCacheControl parses every Cache-Control header line in headers into
// a cacheControl map.
//
// Directive names are lower-cased because they are case-insensitive. For
// "name=value" directives everything after the first '=' is kept as the
// value, with surrounding whitespace removed. When a directive is repeated,
// the last occurrence wins. Quoted values are stored with their quotes.
func parseCacheControl(headers http.Header) cacheControl {
	cc := cacheControl{}
	// Several Cache-Control lines are equivalent to one comma-joined line.
	ccHeader := strings.Join(headers["Cache-Control"], ",")
	for _, part := range strings.Split(ccHeader, ",") {
		part = strings.TrimSpace(part)
		if part == "" {
			continue
		}
		name, value := part, ""
		if i := strings.IndexByte(part, '='); i >= 0 {
			name = strings.TrimSpace(part[:i])
			value = strings.TrimSpace(part[i+1:])
		}
		cc[strings.ToLower(name)] = value
	}
	return cc
}
498
499
500
501
502
503
504
// headerAllCommaSepValues returns every comma-separated field of every
// value of the header called name, with surrounding whitespace removed from
// each field. Empty fields are kept. The result is nil when the header is
// not present.
func headerAllCommaSepValues(headers http.Header, name string) []string {
	var vals []string
	for _, line := range headers[http.CanonicalHeaderKey(name)] {
		for _, field := range strings.Split(line, ",") {
			vals = append(vals, strings.TrimSpace(field))
		}
	}
	return vals
}
516
517
518
519
// cachingReadCloser wraps a ReadCloser, keeps a copy of everything read
// through it, and hands that copy to OnEOF once the underlying reader
// reports io.EOF.
type cachingReadCloser struct {
	// R is the underlying reader.
	R io.ReadCloser
	// OnEOF is called at most once, with a reader over all bytes read from
	// R, when R first returns io.EOF. It may be nil.
	OnEOF func(io.Reader)

	// buf accumulates every byte read from R.
	buf bytes.Buffer
	// eofSeen records that OnEOF has already been triggered.
	eofSeen bool
}

// Read reads from R into p and records the bytes read. The first time R
// returns io.EOF, OnEOF is invoked with the accumulated content; later
// reads at EOF do not invoke it again, so the content is not processed
// twice.
func (r *cachingReadCloser) Read(p []byte) (n int, err error) {
	n, err = r.R.Read(p)
	r.buf.Write(p[:n])
	if err == io.EOF && !r.eofSeen {
		r.eofSeen = true
		if r.OnEOF != nil {
			r.OnEOF(bytes.NewReader(r.buf.Bytes()))
		}
	}
	return n, err
}

// Close closes the underlying reader. OnEOF is not called by Close, so a
// body closed before being read to EOF is never handed to OnEOF.
func (r *cachingReadCloser) Close() error {
	return r.R.Close()
}
545
546
547 func NewMemoryCacheTransport() *Transport {
548 c := NewMemoryCache()
549 t := NewTransport(c)
550 return t
551 }
552
View as plain text