1 package api
2
3 import (
4 "bufio"
5 "bytes"
6 "crypto/sha256"
7 "errors"
8 "fmt"
9 "io"
10 "net/http"
11 "os"
12 "path/filepath"
13 "strings"
14 "sync"
15 "time"
16 )
17
type (
	// cache is the configuration of an on-disk HTTP response cache: the
	// directory entries are written under and how long an entry stays fresh.
	cache struct {
		dir string
		ttl time.Duration
	}

	// cacheRoundTripper is an http.RoundTripper that consults the file
	// storage fs before delegating to the wrapped transport rt.
	cacheRoundTripper struct {
		fs fileStorage
		rt http.RoundTripper
	}

	// fileStorage reads and writes cached responses under dir, treating
	// entries older than ttl as expired. mu is a pointer, so every copy of a
	// fileStorage shares the same lock.
	fileStorage struct {
		dir string
		ttl time.Duration
		mu  *sync.RWMutex
	}

	// readCloser pairs an arbitrary Reader with a separate Closer so the two
	// can come from different sources.
	readCloser struct {
		io.Reader
		io.Closer
	}
)
38
// isCacheableRequest reports whether req is eligible for caching: any GET or
// HEAD request, plus POST requests whose path is exactly "/graphql" or
// "/api/graphql". Methods are compared case-insensitively.
func isCacheableRequest(req *http.Request) bool {
	switch {
	case strings.EqualFold(req.Method, "GET"), strings.EqualFold(req.Method, "HEAD"):
		return true
	case strings.EqualFold(req.Method, "POST"):
		// The only body-carrying requests that are cached are POSTs to the
		// GraphQL endpoints; cacheKey hashes the body so distinct queries
		// get distinct entries.
		return req.URL.Path == "/graphql" || req.URL.Path == "/api/graphql"
	default:
		return false
	}
}
50
// isCacheableResponse reports whether res may be written to the cache: any
// status below 500 except 403 Forbidden.
func isCacheableResponse(res *http.Response) bool {
	if res.StatusCode >= http.StatusInternalServerError {
		return false
	}
	return res.StatusCode != http.StatusForbidden
}
54
// cacheKey derives the cache key for req: the hex-encoded SHA-256 of the
// method, the full URL, the Accept and Authorization headers (each followed by
// ':'), and then the request body.
//
// Hashing consumes req.Body, so on success it is replaced with an in-memory
// copy yielding the same bytes and the original body is closed. A nil body or
// http.NoBody is left untouched (hashing it would add nothing to the digest).
//
// If reading the body fails, the error is returned and req.Body is replaced
// by a reader that replays the bytes already consumed followed by the unread
// remainder of the original body (which is left open). The caller may still
// forward req, and the transport then observes the same read failure instead
// of silently sending a truncated body.
func cacheKey(req *http.Request) (string, error) {
	h := sha256.New()
	fmt.Fprintf(h, "%s:", req.Method)
	fmt.Fprintf(h, "%s:", req.URL.String())
	fmt.Fprintf(h, "%s:", req.Header.Get("Accept"))
	fmt.Fprintf(h, "%s:", req.Header.Get("Authorization"))

	if req.Body != nil && req.Body != http.NoBody {
		orig := req.Body
		buf := &bytes.Buffer{}
		if _, err := io.Copy(h, io.TeeReader(orig, buf)); err != nil {
			req.Body = struct {
				io.Reader
				io.Closer
			}{io.MultiReader(buf, orig), orig}
			return "", err
		}
		// orig was read to EOF; a close error cannot affect the buffered copy.
		_ = orig.Close()
		req.Body = io.NopCloser(buf)
	}

	return fmt.Sprintf("%x", h.Sum(nil)), nil
}
74
75 func (c cache) RoundTripper(rt http.RoundTripper) http.RoundTripper {
76 fs := fileStorage{
77 dir: c.dir,
78 ttl: c.ttl,
79 mu: &sync.RWMutex{},
80 }
81 return cacheRoundTripper{fs: fs, rt: rt}
82 }
83
84 func (crt cacheRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
85 reqDir, reqTTL := requestCacheOptions(req)
86
87 if crt.fs.ttl == 0 && reqTTL == 0 {
88 return crt.rt.RoundTrip(req)
89 }
90
91 if !isCacheableRequest(req) {
92 return crt.rt.RoundTrip(req)
93 }
94
95 origDir := crt.fs.dir
96 if reqDir != "" {
97 crt.fs.dir = reqDir
98 }
99 origTTL := crt.fs.ttl
100 if reqTTL != 0 {
101 crt.fs.ttl = reqTTL
102 }
103
104 key, keyErr := cacheKey(req)
105 if keyErr == nil {
106 if res, err := crt.fs.read(key); err == nil {
107 res.Request = req
108 return res, nil
109 }
110 }
111
112 res, err := crt.rt.RoundTrip(req)
113 if err == nil && keyErr == nil && isCacheableResponse(res) {
114 _ = crt.fs.store(key, res)
115 }
116
117 crt.fs.dir = origDir
118 crt.fs.ttl = origTTL
119
120 return res, err
121 }
122
123
// requestCacheOptions returns the per-request cache overrides carried in the
// X-GH-CACHE-DIR and X-GH-CACHE-TTL headers. A missing or unparsable TTL is
// reported as 0, which callers treat as "no override".
func requestCacheOptions(req *http.Request) (string, time.Duration) {
	dir := req.Header.Get("X-GH-CACHE-DIR")
	rawTTL := req.Header.Get("X-GH-CACHE-TTL")
	if rawTTL == "" {
		return dir, 0
	}
	// time.ParseDuration yields 0 alongside any error, so a malformed value
	// degrades to "no override".
	ttl, _ := time.ParseDuration(rawTTL)
	return dir, ttl
}
133
134 func (fs *fileStorage) filePath(key string) string {
135 if len(key) >= 6 {
136 return filepath.Join(fs.dir, key[0:2], key[2:4], key[4:])
137 }
138 return filepath.Join(fs.dir, key)
139 }
140
141 func (fs *fileStorage) read(key string) (*http.Response, error) {
142 cacheFile := fs.filePath(key)
143
144 fs.mu.RLock()
145 defer fs.mu.RUnlock()
146
147 f, err := os.Open(cacheFile)
148 if err != nil {
149 return nil, err
150 }
151 defer f.Close()
152
153 stat, err := f.Stat()
154 if err != nil {
155 return nil, err
156 }
157
158 age := time.Since(stat.ModTime())
159 if age > fs.ttl {
160 return nil, errors.New("cache expired")
161 }
162
163 body := &bytes.Buffer{}
164 _, err = io.Copy(body, f)
165 if err != nil {
166 return nil, err
167 }
168
169 res, err := http.ReadResponse(bufio.NewReader(body), nil)
170 return res, err
171 }
172
173 func (fs *fileStorage) store(key string, res *http.Response) (storeErr error) {
174 cacheFile := fs.filePath(key)
175
176 fs.mu.Lock()
177 defer fs.mu.Unlock()
178
179 if storeErr = os.MkdirAll(filepath.Dir(cacheFile), 0755); storeErr != nil {
180 return
181 }
182
183 var f *os.File
184 if f, storeErr = os.OpenFile(cacheFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600); storeErr != nil {
185 return
186 }
187
188 defer func() {
189 if err := f.Close(); storeErr == nil && err != nil {
190 storeErr = err
191 }
192 }()
193
194 var origBody io.ReadCloser
195 if res.Body != nil {
196 origBody, res.Body = copyStream(res.Body)
197 defer res.Body.Close()
198 }
199
200 storeErr = res.Write(f)
201 if origBody != nil {
202 res.Body = origBody
203 }
204
205 return
206 }
207
208 func copyStream(r io.ReadCloser) (io.ReadCloser, io.ReadCloser) {
209 b := &bytes.Buffer{}
210 nr := io.TeeReader(r, b)
211 return io.NopCloser(b), &readCloser{
212 Reader: nr,
213 Closer: r,
214 }
215 }
216
View as plain text