...
1 package keyfunc
2
3 import (
4 "bytes"
5 "context"
6 "fmt"
7 "net/http"
8 "sync"
9 "time"
10 )
11
// defaultRefreshTimeout is the per-refresh timeout used when the caller leaves
// the refresh timeout option at its zero value.
var defaultRefreshTimeout = time.Minute
17
18
19 func Get(jwksURL string, options Options) (jwks *JWKS, err error) {
20 jwks = &JWKS{
21 jwksURL: jwksURL,
22 }
23
24 applyOptions(jwks, options)
25
26 if jwks.client == nil {
27 jwks.client = http.DefaultClient
28 }
29 if jwks.requestFactory == nil {
30 jwks.requestFactory = defaultRequestFactory
31 }
32 if jwks.responseExtractor == nil {
33 jwks.responseExtractor = ResponseExtractorStatusOK
34 }
35 if jwks.refreshTimeout == 0 {
36 jwks.refreshTimeout = defaultRefreshTimeout
37 }
38 if !options.JWKUseNoWhitelist && len(jwks.jwkUseWhitelist) == 0 {
39 jwks.jwkUseWhitelist = map[JWKUse]struct{}{
40 UseOmitted: {},
41 UseSignature: {},
42 }
43 }
44
45 err = jwks.refresh()
46 if err != nil {
47 return nil, err
48 }
49
50 if jwks.refreshInterval != 0 || jwks.refreshUnknownKID {
51 jwks.ctx, jwks.cancel = context.WithCancel(context.Background())
52 jwks.refreshRequests = make(chan context.CancelFunc, 1)
53 go jwks.backgroundRefresh()
54 }
55
56 return jwks, nil
57 }
58
59
60
// backgroundRefresh is the long-running refresh loop started by Get. It reacts
// to two triggers:
//
//   - the periodic timer (only when j.refreshInterval is non-zero), which is
//     turned into a refresh request by sending a no-op cancel func on
//     j.refreshRequests; the tick is dropped if that send cannot proceed
//     immediately; and
//   - refresh requests received on j.refreshRequests. Each carries a
//     context.CancelFunc that this loop calls once it is done with the request
//     (presumably the requester waits on it — confirm at the call sites).
//
// When j.refreshRateLimit is non-zero and a request arrives sooner than
// refreshRateLimit after the previous refresh, the request's cancel func is
// called right away. At most one deferred refresh is then queued, to run once
// the rate-limit window has elapsed. Refresh errors are passed to
// j.refreshErrorHandler when it is set. The loop returns when j.ctx is done.
func (j *JWKS) backgroundRefresh() {
	var lastRefresh time.Time
	var queueOnce sync.Once
	// refreshMux guards lastRefresh and queueOnce, which are shared with the
	// deferred-refresh goroutine, and serializes the calls to j.refresh.
	var refreshMux sync.Mutex
	if j.refreshRateLimit != 0 {
		// Backdate lastRefresh so the first request is not rate limited.
		lastRefresh = time.Now().Add(-j.refreshRateLimit)
	}

	// Nothing can ever send on this channel, so with no refresh interval
	// configured the timer case below never fires.
	refreshInterval := make(<-chan time.Time)

	for {
		if j.refreshInterval != 0 {
			// NOTE(review): a new timer is created on every loop iteration, so
			// each handled request restarts the interval countdown.
			refreshInterval = time.After(j.refreshInterval)
		}

		select {
		case <-refreshInterval:
			// Convert the tick into an ordinary refresh request. If the request
			// buffer is already full, skip this tick instead of blocking.
			select {
			case <-j.ctx.Done():
				return
			case j.refreshRequests <- func() {}:
			default:
			}

		case cancel := <-j.refreshRequests:
			refreshMux.Lock()
			if j.refreshRateLimit != 0 && lastRefresh.Add(j.refreshRateLimit).After(time.Now()) {
				// Rate limited: call the request's cancel func now instead of
				// holding it until the deferred refresh completes.
				cancel()

				// Queue at most one deferred refresh. The goroutine resets
				// queueOnce when it finishes so a later request can queue
				// another one.
				queueOnce.Do(func() {
					go func() {
						refreshMux.Lock()
						wait := time.Until(lastRefresh.Add(j.refreshRateLimit))
						refreshMux.Unlock()
						select {
						case <-j.ctx.Done():
							return
						case <-time.After(wait):
						}

						refreshMux.Lock()
						defer refreshMux.Unlock()
						err := j.refresh()
						if err != nil && j.refreshErrorHandler != nil {
							j.refreshErrorHandler(err)
						}

						lastRefresh = time.Now()
						queueOnce = sync.Once{}
					}()
				})
			} else {
				err := j.refresh()
				if err != nil && j.refreshErrorHandler != nil {
					j.refreshErrorHandler(err)
				}

				lastRefresh = time.Now()

				// Not rate limited: the cancel func is called only after the
				// refresh has completed.
				cancel()
			}
			refreshMux.Unlock()

		case <-j.ctx.Done():
			return
		}
	}
}
135
// defaultRequestFactory builds the request used to fetch the JWKS when no
// request factory is configured: a GET to url, bound to ctx.
//
// A GET carries no body. Passing nil, rather than an empty reader, is the
// idiomatic way to say so and avoids allocating a reader on every refresh.
// An error is returned when url cannot be parsed or ctx is nil.
func defaultRequestFactory(ctx context.Context, url string) (*http.Request, error) {
	return http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
}
139
140
141 func (j *JWKS) refresh() (err error) {
142 var ctx context.Context
143 var cancel context.CancelFunc
144 if j.ctx != nil {
145 ctx, cancel = context.WithTimeout(j.ctx, j.refreshTimeout)
146 } else {
147 ctx, cancel = context.WithTimeout(context.Background(), j.refreshTimeout)
148 }
149 defer cancel()
150
151 req, err := j.requestFactory(ctx, j.jwksURL)
152 if err != nil {
153 return fmt.Errorf("failed to create request via factory function: %w", err)
154 }
155
156 resp, err := j.client.Do(req)
157 if err != nil {
158 return err
159 }
160
161 jwksBytes, err := j.responseExtractor(ctx, resp)
162 if err != nil {
163 return fmt.Errorf("failed to extract response via extractor function: %w", err)
164 }
165
166
167 if len(jwksBytes) != 0 && bytes.Equal(jwksBytes, j.raw) {
168 return nil
169 }
170 j.raw = jwksBytes
171
172 updated, err := NewJSON(jwksBytes)
173 if err != nil {
174 return err
175 }
176
177 j.mux.Lock()
178 defer j.mux.Unlock()
179 j.keys = updated.keys
180
181 if j.givenKeys != nil {
182 for kid, key := range j.givenKeys {
183
184 if !j.givenKIDOverride {
185 if _, ok := j.keys[kid]; ok {
186 continue
187 }
188 }
189
190 j.keys[kid] = parsedJWK{public: key.inter}
191 }
192 }
193
194 return nil
195 }
196
View as plain text