...
1 package keyfunc
2
3 import (
4 "bytes"
5 "context"
6 "errors"
7 "fmt"
8 "net/http"
9 "sync"
10 "time"
11 )
12
var (
	// ErrRefreshImpossible is returned by Refresh when the JWKS was not created from a
	// remote resource, i.e. it has no JWKS URL to fetch from.
	ErrRefreshImpossible = errors.New("refresh impossible: JWKS was not created from a remote resource")

	// defaultRefreshTimeout is the timeout Get applies to each refresh request when no
	// refresh timeout was configured.
	defaultRefreshTimeout = time.Minute
)
22
23
24 func Get(jwksURL string, options Options) (jwks *JWKS, err error) {
25 jwks = &JWKS{
26 jwksURL: jwksURL,
27 }
28
29 applyOptions(jwks, options)
30
31 if jwks.client == nil {
32 jwks.client = http.DefaultClient
33 }
34 if jwks.requestFactory == nil {
35 jwks.requestFactory = defaultRequestFactory
36 }
37 if jwks.responseExtractor == nil {
38 jwks.responseExtractor = ResponseExtractorStatusOK
39 }
40 if jwks.refreshTimeout == 0 {
41 jwks.refreshTimeout = defaultRefreshTimeout
42 }
43 if !options.JWKUseNoWhitelist && len(jwks.jwkUseWhitelist) == 0 {
44 jwks.jwkUseWhitelist = map[JWKUse]struct{}{
45 UseOmitted: {},
46 UseSignature: {},
47 }
48 }
49
50 err = jwks.refresh()
51 if err != nil {
52 if options.TolerateInitialJWKHTTPError {
53 if jwks.refreshErrorHandler != nil {
54 jwks.refreshErrorHandler(err)
55 }
56 jwks.keys = make(map[string]parsedJWK)
57 } else {
58 return nil, err
59 }
60 }
61
62 if jwks.refreshInterval != 0 || jwks.refreshUnknownKID {
63 if jwks.ctx == nil {
64 jwks.ctx = context.Background()
65 }
66 jwks.ctx, jwks.cancel = context.WithCancel(jwks.ctx)
67 jwks.refreshRequests = make(chan refreshRequest, 1)
68 go jwks.backgroundRefresh()
69 }
70
71 return jwks, nil
72 }
73
74
75
76
77
78
79 func (j *JWKS) Refresh(ctx context.Context, options RefreshOptions) error {
80 if j.jwksURL == "" {
81 return ErrRefreshImpossible
82 }
83
84
85 if j.refreshInterval != 0 || j.refreshUnknownKID {
86 ctx, cancel := context.WithCancel(ctx)
87
88 req := refreshRequest{
89 cancel: cancel,
90 ignoreRateLimit: options.IgnoreRateLimit,
91 }
92
93 select {
94 case <-ctx.Done():
95 return fmt.Errorf("failed to send request refresh to background goroutine: %w", j.ctx.Err())
96 case j.refreshRequests <- req:
97 }
98
99 <-ctx.Done()
100
101 if !errors.Is(ctx.Err(), context.Canceled) {
102 return fmt.Errorf("unexpected keyfunc background refresh context error: %w", ctx.Err())
103 }
104 } else {
105 err := j.refresh()
106 if err != nil {
107 return fmt.Errorf("failed to refresh JWKS: %w", err)
108 }
109 }
110
111 return nil
112 }
113
114
115
116 func (j *JWKS) backgroundRefresh() {
117 var lastRefresh time.Time
118 var queueOnce sync.Once
119 var refreshMux sync.Mutex
120 if j.refreshRateLimit != 0 {
121 lastRefresh = time.Now().Add(-j.refreshRateLimit)
122 }
123
124
125 refreshInterval := make(<-chan time.Time)
126
127 refresh := func() {
128 err := j.refresh()
129 if err != nil && j.refreshErrorHandler != nil {
130 j.refreshErrorHandler(err)
131 }
132 lastRefresh = time.Now()
133 }
134
135
136 for {
137 if j.refreshInterval != 0 {
138 refreshInterval = time.After(j.refreshInterval)
139 }
140
141 select {
142 case <-refreshInterval:
143 select {
144 case <-j.ctx.Done():
145 return
146 case j.refreshRequests <- refreshRequest{}:
147 default:
148 }
149
150 case req := <-j.refreshRequests:
151 refreshMux.Lock()
152 if req.ignoreRateLimit {
153 refresh()
154 } else if j.refreshRateLimit != 0 && lastRefresh.Add(j.refreshRateLimit).After(time.Now()) {
155
156 queueOnce.Do(func() {
157 go func() {
158 refreshMux.Lock()
159 wait := time.Until(lastRefresh.Add(j.refreshRateLimit))
160 refreshMux.Unlock()
161 select {
162 case <-j.ctx.Done():
163 return
164 case <-time.After(wait):
165 }
166
167 refreshMux.Lock()
168 defer refreshMux.Unlock()
169 refresh()
170 queueOnce = sync.Once{}
171 }()
172 })
173 } else {
174 refresh()
175 }
176 if req.cancel != nil {
177 req.cancel()
178 }
179 refreshMux.Unlock()
180
181
182 case <-j.ctx.Done():
183 return
184 }
185 }
186 }
187
// defaultRequestFactory builds the request used to fetch the JWKS when no custom
// request factory is configured. The request is a GET for url, bound to ctx, with an
// empty body.
func defaultRequestFactory(ctx context.Context, url string) (*http.Request, error) {
	emptyBody := bytes.NewReader(nil)
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, emptyBody)
	return req, err
}
191
192
193 func (j *JWKS) refresh() (err error) {
194 var ctx context.Context
195 var cancel context.CancelFunc
196 if j.ctx != nil {
197 ctx, cancel = context.WithTimeout(j.ctx, j.refreshTimeout)
198 } else {
199 ctx, cancel = context.WithTimeout(context.Background(), j.refreshTimeout)
200 }
201 defer cancel()
202
203 req, err := j.requestFactory(ctx, j.jwksURL)
204 if err != nil {
205 return fmt.Errorf("failed to create request via factory function: %w", err)
206 }
207
208 resp, err := j.client.Do(req)
209 if err != nil {
210 return err
211 }
212
213 jwksBytes, err := j.responseExtractor(ctx, resp)
214 if err != nil {
215 return fmt.Errorf("failed to extract response via extractor function: %w", err)
216 }
217
218
219 if len(jwksBytes) != 0 && bytes.Equal(jwksBytes, j.raw) {
220 return nil
221 }
222 j.raw = jwksBytes
223
224 updated, err := NewJSON(jwksBytes)
225 if err != nil {
226 return err
227 }
228
229 j.mux.Lock()
230 defer j.mux.Unlock()
231 j.keys = updated.keys
232
233 if j.givenKeys != nil {
234 for kid, key := range j.givenKeys {
235
236 if !j.givenKIDOverride {
237 if _, ok := j.keys[kid]; ok {
238 continue
239 }
240 }
241
242 j.keys[kid] = parsedJWK{public: key.inter}
243 }
244 }
245
246 return nil
247 }
248
View as plain text