...
1
2
3
4
5 package idtoken
6
7 import (
8 "context"
9 "encoding/json"
10 "fmt"
11 "net/http"
12 "strconv"
13 "strings"
14 "sync"
15 "time"
16 )
17
18 type cachingClient struct {
19 client *http.Client
20
21
22
23 clock func() time.Time
24
25 mu sync.Mutex
26 certs map[string]*cachedResponse
27 }
28
29 func newCachingClient(client *http.Client) *cachingClient {
30 return &cachingClient{
31 client: client,
32 certs: make(map[string]*cachedResponse, 2),
33 }
34 }
35
36 type cachedResponse struct {
37 resp *certResponse
38 exp time.Time
39 }
40
41 func (c *cachingClient) getCert(ctx context.Context, url string) (*certResponse, error) {
42 if response, ok := c.get(url); ok {
43 return response, nil
44 }
45 req, err := http.NewRequest(http.MethodGet, url, nil)
46 if err != nil {
47 return nil, err
48 }
49 req = req.WithContext(ctx)
50 resp, err := c.client.Do(req)
51 if err != nil {
52 return nil, err
53 }
54 defer resp.Body.Close()
55 if resp.StatusCode != http.StatusOK {
56 return nil, fmt.Errorf("idtoken: unable to retrieve cert, got status code %d", resp.StatusCode)
57 }
58
59 certResp := &certResponse{}
60 if err := json.NewDecoder(resp.Body).Decode(certResp); err != nil {
61 return nil, err
62
63 }
64 c.set(url, certResp, resp.Header)
65 return certResp, nil
66 }
67
68 func (c *cachingClient) now() time.Time {
69 if c.clock != nil {
70 return c.clock()
71 }
72 return time.Now()
73 }
74
75 func (c *cachingClient) get(url string) (*certResponse, bool) {
76 c.mu.Lock()
77 defer c.mu.Unlock()
78 cachedResp, ok := c.certs[url]
79 if !ok {
80 return nil, false
81 }
82 if c.now().After(cachedResp.exp) {
83 return nil, false
84 }
85 return cachedResp.resp, true
86 }
87
88 func (c *cachingClient) set(url string, resp *certResponse, headers http.Header) {
89 exp := c.calculateExpireTime(headers)
90 c.mu.Lock()
91 c.certs[url] = &cachedResponse{resp: resp, exp: exp}
92 c.mu.Unlock()
93 }
94
95
96
97
98 func (c *cachingClient) calculateExpireTime(headers http.Header) time.Time {
99 var maxAge int
100 cc := strings.Split(headers.Get("cache-control"), ",")
101 for _, v := range cc {
102 if strings.Contains(v, "max-age") {
103 ss := strings.Split(v, "=")
104 if len(ss) < 2 {
105 return c.now()
106 }
107 ma, err := strconv.Atoi(ss[1])
108 if err != nil {
109 return c.now()
110 }
111 maxAge = ma
112 }
113 }
114 a := headers.Get("age")
115 if a == "" {
116 return c.now().Add(time.Duration(maxAge) * time.Second)
117 }
118 age, err := strconv.Atoi(a)
119 if err != nil {
120 return c.now()
121 }
122 return c.now().Add(time.Duration(maxAge-age) * time.Second)
123 }
124
View as plain text