...

Source file src/github.com/cli/go-gh/v2/pkg/api/cache.go

Documentation: github.com/cli/go-gh/v2/pkg/api

     1  package api
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"crypto/sha256"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"net/http"
    11  	"os"
    12  	"path/filepath"
    13  	"strings"
    14  	"sync"
    15  	"time"
    16  )
    17  
// cache holds the settings from which a caching http.RoundTripper is built;
// see the RoundTripper method, which copies both fields into a fileStorage.
type cache struct {
	// dir is the root directory under which cached responses are stored.
	dir string
	// ttl is the maximum age of a cache file before it is treated as expired.
	// When it is zero and the request carries no TTL override, caching is
	// bypassed entirely.
	ttl time.Duration
}
    22  
// cacheRoundTripper is an http.RoundTripper that consults a file-backed cache
// before delegating to another transport.
type cacheRoundTripper struct {
	// fs is the on-disk storage used to look up and save responses.
	fs fileStorage
	// rt is the wrapped transport that performs the real request whenever the
	// cache is bypassed or has no usable entry.
	rt http.RoundTripper
}
    27  
// fileStorage persists HTTP responses as files below dir and serves them back
// while they are younger than ttl.
type fileStorage struct {
	// dir is the root directory of the cache tree.
	dir string
	// ttl is compared against the age of a cache file (time since its
	// modification time) to decide whether the entry has expired.
	ttl time.Duration
	// mu serializes access to the cache files: read takes the read lock and
	// store takes the write lock. It is a pointer so that copies of a
	// fileStorage value keep sharing one lock.
	mu  *sync.RWMutex
}
    33  
// readCloser glues an independent io.Reader and io.Closer together into a
// single io.ReadCloser. copyStream uses it to read through a tee while still
// closing the underlying stream.
type readCloser struct {
	io.Reader
	io.Closer
}
    38  
    39  func isCacheableRequest(req *http.Request) bool {
    40  	if strings.EqualFold(req.Method, "GET") || strings.EqualFold(req.Method, "HEAD") {
    41  		return true
    42  	}
    43  
    44  	if strings.EqualFold(req.Method, "POST") && (req.URL.Path == "/graphql" || req.URL.Path == "/api/graphql") {
    45  		return true
    46  	}
    47  
    48  	return false
    49  }
    50  
    51  func isCacheableResponse(res *http.Response) bool {
    52  	return res.StatusCode < 500 && res.StatusCode != 403
    53  }
    54  
    55  func cacheKey(req *http.Request) (string, error) {
    56  	h := sha256.New()
    57  	fmt.Fprintf(h, "%s:", req.Method)
    58  	fmt.Fprintf(h, "%s:", req.URL.String())
    59  	fmt.Fprintf(h, "%s:", req.Header.Get("Accept"))
    60  	fmt.Fprintf(h, "%s:", req.Header.Get("Authorization"))
    61  
    62  	if req.Body != nil {
    63  		var bodyCopy io.ReadCloser
    64  		req.Body, bodyCopy = copyStream(req.Body)
    65  		defer bodyCopy.Close()
    66  		if _, err := io.Copy(h, bodyCopy); err != nil {
    67  			return "", err
    68  		}
    69  	}
    70  
    71  	digest := h.Sum(nil)
    72  	return fmt.Sprintf("%x", digest), nil
    73  }
    74  
    75  func (c cache) RoundTripper(rt http.RoundTripper) http.RoundTripper {
    76  	fs := fileStorage{
    77  		dir: c.dir,
    78  		ttl: c.ttl,
    79  		mu:  &sync.RWMutex{},
    80  	}
    81  	return cacheRoundTripper{fs: fs, rt: rt}
    82  }
    83  
    84  func (crt cacheRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
    85  	reqDir, reqTTL := requestCacheOptions(req)
    86  
    87  	if crt.fs.ttl == 0 && reqTTL == 0 {
    88  		return crt.rt.RoundTrip(req)
    89  	}
    90  
    91  	if !isCacheableRequest(req) {
    92  		return crt.rt.RoundTrip(req)
    93  	}
    94  
    95  	origDir := crt.fs.dir
    96  	if reqDir != "" {
    97  		crt.fs.dir = reqDir
    98  	}
    99  	origTTL := crt.fs.ttl
   100  	if reqTTL != 0 {
   101  		crt.fs.ttl = reqTTL
   102  	}
   103  
   104  	key, keyErr := cacheKey(req)
   105  	if keyErr == nil {
   106  		if res, err := crt.fs.read(key); err == nil {
   107  			res.Request = req
   108  			return res, nil
   109  		}
   110  	}
   111  
   112  	res, err := crt.rt.RoundTrip(req)
   113  	if err == nil && keyErr == nil && isCacheableResponse(res) {
   114  		_ = crt.fs.store(key, res)
   115  	}
   116  
   117  	crt.fs.dir = origDir
   118  	crt.fs.ttl = origTTL
   119  
   120  	return res, err
   121  }
   122  
   123  // Allow an individual request to override cache options.
   124  func requestCacheOptions(req *http.Request) (string, time.Duration) {
   125  	var dur time.Duration
   126  	dir := req.Header.Get("X-GH-CACHE-DIR")
   127  	ttl := req.Header.Get("X-GH-CACHE-TTL")
   128  	if ttl != "" {
   129  		dur, _ = time.ParseDuration(ttl)
   130  	}
   131  	return dir, dur
   132  }
   133  
   134  func (fs *fileStorage) filePath(key string) string {
   135  	if len(key) >= 6 {
   136  		return filepath.Join(fs.dir, key[0:2], key[2:4], key[4:])
   137  	}
   138  	return filepath.Join(fs.dir, key)
   139  }
   140  
   141  func (fs *fileStorage) read(key string) (*http.Response, error) {
   142  	cacheFile := fs.filePath(key)
   143  
   144  	fs.mu.RLock()
   145  	defer fs.mu.RUnlock()
   146  
   147  	f, err := os.Open(cacheFile)
   148  	if err != nil {
   149  		return nil, err
   150  	}
   151  	defer f.Close()
   152  
   153  	stat, err := f.Stat()
   154  	if err != nil {
   155  		return nil, err
   156  	}
   157  
   158  	age := time.Since(stat.ModTime())
   159  	if age > fs.ttl {
   160  		return nil, errors.New("cache expired")
   161  	}
   162  
   163  	body := &bytes.Buffer{}
   164  	_, err = io.Copy(body, f)
   165  	if err != nil {
   166  		return nil, err
   167  	}
   168  
   169  	res, err := http.ReadResponse(bufio.NewReader(body), nil)
   170  	return res, err
   171  }
   172  
   173  func (fs *fileStorage) store(key string, res *http.Response) (storeErr error) {
   174  	cacheFile := fs.filePath(key)
   175  
   176  	fs.mu.Lock()
   177  	defer fs.mu.Unlock()
   178  
   179  	if storeErr = os.MkdirAll(filepath.Dir(cacheFile), 0755); storeErr != nil {
   180  		return
   181  	}
   182  
   183  	var f *os.File
   184  	if f, storeErr = os.OpenFile(cacheFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600); storeErr != nil {
   185  		return
   186  	}
   187  
   188  	defer func() {
   189  		if err := f.Close(); storeErr == nil && err != nil {
   190  			storeErr = err
   191  		}
   192  	}()
   193  
   194  	var origBody io.ReadCloser
   195  	if res.Body != nil {
   196  		origBody, res.Body = copyStream(res.Body)
   197  		defer res.Body.Close()
   198  	}
   199  
   200  	storeErr = res.Write(f)
   201  	if origBody != nil {
   202  		res.Body = origBody
   203  	}
   204  
   205  	return
   206  }
   207  
   208  func copyStream(r io.ReadCloser) (io.ReadCloser, io.ReadCloser) {
   209  	b := &bytes.Buffer{}
   210  	nr := io.TeeReader(r, b)
   211  	return io.NopCloser(b), &readCloser{
   212  		Reader: nr,
   213  		Closer: r,
   214  	}
   215  }
   216  

View as plain text