...
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package idtoken
16
17 import (
18 "context"
19 "encoding/json"
20 "fmt"
21 "net/http"
22 "strconv"
23 "strings"
24 "sync"
25 "time"
26 )
27
28 type cachingClient struct {
29 client *http.Client
30
31
32
33 clock func() time.Time
34
35 mu sync.Mutex
36 certs map[string]*cachedResponse
37 }
38
39 func newCachingClient(client *http.Client) *cachingClient {
40 return &cachingClient{
41 client: client,
42 certs: make(map[string]*cachedResponse, 2),
43 }
44 }
45
46 type cachedResponse struct {
47 resp *certResponse
48 exp time.Time
49 }
50
51 func (c *cachingClient) getCert(ctx context.Context, url string) (*certResponse, error) {
52 if response, ok := c.get(url); ok {
53 return response, nil
54 }
55 req, err := http.NewRequest(http.MethodGet, url, nil)
56 if err != nil {
57 return nil, err
58 }
59 req = req.WithContext(ctx)
60 resp, err := c.client.Do(req)
61 if err != nil {
62 return nil, err
63 }
64 defer resp.Body.Close()
65 if resp.StatusCode != http.StatusOK {
66 return nil, fmt.Errorf("idtoken: unable to retrieve cert, got status code %d", resp.StatusCode)
67 }
68
69 certResp := &certResponse{}
70 if err := json.NewDecoder(resp.Body).Decode(certResp); err != nil {
71 return nil, err
72
73 }
74 c.set(url, certResp, resp.Header)
75 return certResp, nil
76 }
77
78 func (c *cachingClient) now() time.Time {
79 if c.clock != nil {
80 return c.clock()
81 }
82 return time.Now()
83 }
84
85 func (c *cachingClient) get(url string) (*certResponse, bool) {
86 c.mu.Lock()
87 defer c.mu.Unlock()
88 cachedResp, ok := c.certs[url]
89 if !ok {
90 return nil, false
91 }
92 if c.now().After(cachedResp.exp) {
93 return nil, false
94 }
95 return cachedResp.resp, true
96 }
97
98 func (c *cachingClient) set(url string, resp *certResponse, headers http.Header) {
99 exp := c.calculateExpireTime(headers)
100 c.mu.Lock()
101 c.certs[url] = &cachedResponse{resp: resp, exp: exp}
102 c.mu.Unlock()
103 }
104
105
106
107
108 func (c *cachingClient) calculateExpireTime(headers http.Header) time.Time {
109 var maxAge int
110 cc := strings.Split(headers.Get("cache-control"), ",")
111 for _, v := range cc {
112 if strings.Contains(v, "max-age") {
113 ss := strings.Split(v, "=")
114 if len(ss) < 2 {
115 return c.now()
116 }
117 ma, err := strconv.Atoi(ss[1])
118 if err != nil {
119 return c.now()
120 }
121 maxAge = ma
122 }
123 }
124 a := headers.Get("age")
125 if a == "" {
126 return c.now().Add(time.Duration(maxAge) * time.Second)
127 }
128 age, err := strconv.Atoi(a)
129 if err != nil {
130 return c.now()
131 }
132 return c.now().Add(time.Duration(maxAge-age) * time.Second)
133 }
134
View as plain text