1 package jwk
2
3 import (
4 "context"
5 "net/http"
6 "reflect"
7 "sync"
8 "time"
9
10 "github.com/lestrrat-go/backoff/v2"
11 "github.com/lestrrat-go/httpcc"
12 "github.com/pkg/errors"
13 )
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28 type AutoRefresh struct {
29 errSink chan AutoRefreshError
30 cache map[string]Set
31 configureCh chan struct{}
32 removeCh chan removeReq
33 fetching map[string]chan struct{}
34 muErrSink sync.Mutex
35 muCache sync.RWMutex
36 muFetching sync.Mutex
37 muRegistry sync.RWMutex
38 registry map[string]*target
39 resetTimerCh chan *resetTimerReq
40 }
41
42 type target struct {
43
44 backoff backoff.Policy
45
46
47
48 httpcl HTTPClient
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74 refreshInterval *time.Duration
75 minRefreshInterval time.Duration
76
77 url string
78
79
80
81 timer *time.Timer
82
83
84 sem chan struct{}
85
86
87 lastRefresh time.Time
88 nextRefresh time.Time
89
90 wl Whitelist
91 parseOptions []ParseOption
92 }
93
94 type resetTimerReq struct {
95 t *target
96 d time.Duration
97 }
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115 func NewAutoRefresh(ctx context.Context) *AutoRefresh {
116 af := &AutoRefresh{
117 cache: make(map[string]Set),
118 configureCh: make(chan struct{}),
119 removeCh: make(chan removeReq),
120 fetching: make(map[string]chan struct{}),
121 registry: make(map[string]*target),
122 resetTimerCh: make(chan *resetTimerReq),
123 }
124 go af.refreshLoop(ctx)
125 return af
126 }
127
128 func (af *AutoRefresh) getCached(url string) (Set, bool) {
129 af.muCache.RLock()
130 ks, ok := af.cache[url]
131 af.muCache.RUnlock()
132 if ok {
133 return ks, true
134 }
135 return nil, false
136 }
137
// removeReq asks the refresh loop to unregister url; the outcome (nil on
// success) is sent back on replyCh.
type removeReq struct {
	url     string     // URL to unregister
	replyCh chan error // receives the result of the removal
}
142
143
144
145 func (af *AutoRefresh) Remove(url string) error {
146 ch := make(chan error)
147 af.removeCh <- removeReq{replyCh: ch, url: url}
148 return <-ch
149 }
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167 func (af *AutoRefresh) Configure(url string, options ...AutoRefreshOption) {
168 var httpcl HTTPClient = http.DefaultClient
169 var hasRefreshInterval bool
170 var refreshInterval time.Duration
171 var wl Whitelist
172 var parseOptions []ParseOption
173 minRefreshInterval := time.Hour
174 bo := backoff.Null()
175 for _, option := range options {
176 if v, ok := option.(ParseOption); ok {
177 parseOptions = append(parseOptions, v)
178 continue
179 }
180
181
182 switch option.Ident() {
183 case identFetchBackoff{}:
184 bo = option.Value().(backoff.Policy)
185 case identRefreshInterval{}:
186 refreshInterval = option.Value().(time.Duration)
187 hasRefreshInterval = true
188 case identMinRefreshInterval{}:
189 minRefreshInterval = option.Value().(time.Duration)
190 case identHTTPClient{}:
191 httpcl = option.Value().(HTTPClient)
192 case identFetchWhitelist{}:
193 wl = option.Value().(Whitelist)
194 }
195 }
196
197 af.muRegistry.Lock()
198 t, ok := af.registry[url]
199 if ok {
200 if t.httpcl != httpcl {
201 t.httpcl = httpcl
202 }
203
204 if t.minRefreshInterval != minRefreshInterval {
205 t.minRefreshInterval = minRefreshInterval
206 }
207
208 if t.refreshInterval != nil {
209 if !hasRefreshInterval {
210 t.refreshInterval = nil
211 } else if *t.refreshInterval != refreshInterval {
212 *t.refreshInterval = refreshInterval
213 }
214 } else {
215 if hasRefreshInterval {
216 t.refreshInterval = &refreshInterval
217 }
218 }
219
220 if t.wl != wl {
221 t.wl = wl
222 }
223
224 t.parseOptions = parseOptions
225 } else {
226 t = &target{
227 backoff: bo,
228 httpcl: httpcl,
229 minRefreshInterval: minRefreshInterval,
230 url: url,
231 sem: make(chan struct{}, 1),
232
233
234
235 timer: time.NewTimer(24 * time.Hour),
236 wl: wl,
237 parseOptions: parseOptions,
238 }
239 if hasRefreshInterval {
240 t.refreshInterval = &refreshInterval
241 }
242
243
244 af.registry[url] = t
245 }
246 af.muRegistry.Unlock()
247
248
249 af.configureCh <- struct{}{}
250 }
251
252 func (af *AutoRefresh) releaseFetching(url string) {
253
254
255 af.muFetching.Lock()
256 fetchingCh, ok := af.fetching[url]
257 if !ok {
258
259 af.muFetching.Unlock()
260 return
261 }
262 delete(af.fetching, url)
263 close(fetchingCh)
264 af.muFetching.Unlock()
265 }
266
267
268 func (af *AutoRefresh) IsRegistered(url string) bool {
269 _, ok := af.getRegistered(url)
270 return ok
271 }
272
273
274 func (af *AutoRefresh) getRegistered(url string) (*target, bool) {
275 af.muRegistry.RLock()
276 t, ok := af.registry[url]
277 af.muRegistry.RUnlock()
278 return t, ok
279 }
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295 func (af *AutoRefresh) Fetch(ctx context.Context, url string) (Set, error) {
296 if _, ok := af.getRegistered(url); !ok {
297 return nil, errors.Errorf(`url %s must be configured using "Configure()" first`, url)
298 }
299
300 ks, found := af.getCached(url)
301 if found {
302 return ks, nil
303 }
304
305 return af.refresh(ctx, url)
306 }
307
308
309
310
311
312
313 func (af *AutoRefresh) Refresh(ctx context.Context, url string) (Set, error) {
314 if _, ok := af.getRegistered(url); !ok {
315 return nil, errors.Errorf(`url %s must be configured using "Configure()" first`, url)
316 }
317
318 return af.refresh(ctx, url)
319 }
320
321 func (af *AutoRefresh) refresh(ctx context.Context, url string) (Set, error) {
322
323
324 af.muFetching.Lock()
325 fetchingCh, fetching := af.fetching[url]
326
327
328 if fetching {
329 af.muFetching.Unlock()
330 select {
331 case <-ctx.Done():
332 return nil, ctx.Err()
333 case <-fetchingCh:
334 }
335 } else {
336 fetchingCh = make(chan struct{})
337 af.fetching[url] = fetchingCh
338 af.muFetching.Unlock()
339
340
341 defer af.releaseFetching(url)
342
343
344 if err := af.doRefreshRequest(ctx, url, false); err != nil {
345 return nil, errors.Wrapf(err, `failed to fetch resource pointed by %s`, url)
346 }
347 }
348
349
350 ks, ok := af.getCached(url)
351 if !ok {
352 return nil, errors.New("cache was not populated after explicit refresh")
353 }
354
355 return ks, nil
356 }
357
358
359 func (af *AutoRefresh) refreshLoop(ctx context.Context) {
360
361
362
363
364
365 const (
366 ctxDoneIdx = iota
367 configureIdx
368 resetTimerIdx
369 removeIdx
370 baseSelcasesLen
371 )
372
373 baseSelcases := make([]reflect.SelectCase, baseSelcasesLen)
374 baseSelcases[ctxDoneIdx] = reflect.SelectCase{
375 Dir: reflect.SelectRecv,
376 Chan: reflect.ValueOf(ctx.Done()),
377 }
378 baseSelcases[configureIdx] = reflect.SelectCase{
379 Dir: reflect.SelectRecv,
380 Chan: reflect.ValueOf(af.configureCh),
381 }
382 baseSelcases[resetTimerIdx] = reflect.SelectCase{
383 Dir: reflect.SelectRecv,
384 Chan: reflect.ValueOf(af.resetTimerCh),
385 }
386 baseSelcases[removeIdx] = reflect.SelectCase{
387 Dir: reflect.SelectRecv,
388 Chan: reflect.ValueOf(af.removeCh),
389 }
390
391 var targets []*target
392 var selcases []reflect.SelectCase
393 for {
394
395
396
397 af.muRegistry.RLock()
398 if cap(targets) < len(af.registry) {
399 targets = make([]*target, 0, len(af.registry))
400 } else {
401 targets = targets[:0]
402 }
403
404 if cap(selcases) < len(af.registry) {
405 selcases = make([]reflect.SelectCase, 0, len(af.registry)+baseSelcasesLen)
406 } else {
407 selcases = selcases[:0]
408 }
409 selcases = append(selcases, baseSelcases...)
410
411 for _, data := range af.registry {
412 targets = append(targets, data)
413 selcases = append(selcases, reflect.SelectCase{
414 Dir: reflect.SelectRecv,
415 Chan: reflect.ValueOf(data.timer.C),
416 })
417 }
418 af.muRegistry.RUnlock()
419
420 chosen, recv, recvOK := reflect.Select(selcases)
421 switch chosen {
422 case ctxDoneIdx:
423
424 return
425 case configureIdx:
426
427
428
429 continue
430 case resetTimerIdx:
431
432
433 if !recvOK {
434 continue
435 }
436
437 req := recv.Interface().(*resetTimerReq)
438 t := req.t
439 d := req.d
440 if !t.timer.Stop() {
441 select {
442 case <-t.timer.C:
443 default:
444 }
445 }
446 t.timer.Reset(d)
447 case removeIdx:
448
449
450 req := recv.Interface().(removeReq)
451 replyCh := req.replyCh
452 url := req.url
453 af.muRegistry.Lock()
454 if _, ok := af.registry[url]; !ok {
455 replyCh <- errors.Errorf(`invalid url %q (not registered)`, url)
456 } else {
457 delete(af.registry, url)
458 replyCh <- nil
459 }
460 af.muRegistry.Unlock()
461 default:
462
463 if !recvOK {
464 continue
465 }
466
467
468 t := targets[chosen-baseSelcasesLen]
469
470
471
472
473 select {
474 case t.sem <- struct{}{}:
475
476 default:
477 continue
478 }
479
480 go func() {
481
482 af.doRefreshRequest(ctx, t.url, true)
483 <-t.sem
484 }()
485 }
486 }
487 }
488
489 func (af *AutoRefresh) doRefreshRequest(ctx context.Context, url string, enableBackoff bool) error {
490 af.muRegistry.RLock()
491 t, ok := af.registry[url]
492
493 if !ok {
494 af.muRegistry.RUnlock()
495 return errors.Errorf(`url "%s" is not registered`, url)
496 }
497
498
499
500 parseOptions := t.parseOptions
501 fetchOptions := []FetchOption{WithHTTPClient(t.httpcl)}
502 if enableBackoff {
503 fetchOptions = append(fetchOptions, WithFetchBackoff(t.backoff))
504 }
505 if t.wl != nil {
506 fetchOptions = append(fetchOptions, WithFetchWhitelist(t.wl))
507 }
508 af.muRegistry.RUnlock()
509
510 res, err := fetch(ctx, url, fetchOptions...)
511 if err == nil {
512 if res.StatusCode != http.StatusOK {
513
514
515 err = errors.Errorf(`bad response status code (%d)`, res.StatusCode)
516 } else {
517 defer res.Body.Close()
518 keyset, parseErr := ParseReader(res.Body, parseOptions...)
519 if parseErr == nil {
520
521 af.muCache.Lock()
522 af.cache[url] = keyset
523 af.muCache.Unlock()
524 af.muRegistry.RLock()
525 nextInterval := calculateRefreshDuration(res, t.refreshInterval, t.minRefreshInterval)
526 af.muRegistry.RUnlock()
527 rtr := &resetTimerReq{
528 t: t,
529 d: nextInterval,
530 }
531 select {
532 case <-ctx.Done():
533 return ctx.Err()
534 case af.resetTimerCh <- rtr:
535 }
536
537 now := time.Now()
538 af.muRegistry.Lock()
539 t.lastRefresh = now.Local()
540 t.nextRefresh = now.Add(nextInterval).Local()
541 af.muRegistry.Unlock()
542 return nil
543 }
544 err = parseErr
545 }
546 }
547
548
549
550
551
552 if err != nil {
553 select {
554 case af.errSink <- AutoRefreshError{Error: err, URL: url}:
555 default:
556 }
557 }
558
559
560
561
562
563
564
565 rtr := &resetTimerReq{
566 t: t,
567 d: calculateRefreshDuration(res, t.refreshInterval, t.minRefreshInterval),
568 }
569 select {
570 case <-ctx.Done():
571 return ctx.Err()
572 case af.resetTimerCh <- rtr:
573 }
574
575 return err
576 }
577
578
579
580
581
582
583
584
585 func (af *AutoRefresh) ErrorSink(ch chan AutoRefreshError) {
586 af.muErrSink.Lock()
587 af.errSink = ch
588 af.muErrSink.Unlock()
589 }
590
591 func calculateRefreshDuration(res *http.Response, refreshInterval *time.Duration, minRefreshInterval time.Duration) time.Duration {
592
593 if refreshInterval != nil {
594 return *refreshInterval
595 }
596
597 if res != nil {
598 if v := res.Header.Get(`Cache-Control`); v != "" {
599 dir, err := httpcc.ParseResponse(v)
600 if err == nil {
601 maxAge, ok := dir.MaxAge()
602 if ok {
603 resDuration := time.Duration(maxAge) * time.Second
604 if resDuration > minRefreshInterval {
605 return resDuration
606 }
607 return minRefreshInterval
608 }
609
610 }
611
612 }
613
614 if v := res.Header.Get(`Expires`); v != "" {
615 expires, err := http.ParseTime(v)
616 if err == nil {
617 resDuration := time.Until(expires)
618 if resDuration > minRefreshInterval {
619 return resDuration
620 }
621 return minRefreshInterval
622 }
623
624 }
625 }
626
627
628 return minRefreshInterval
629 }
630
631
632
633
// TargetSnapshot is a point-in-time view of one registered URL, as produced
// by AutoRefresh.Snapshot.
type TargetSnapshot struct {
	// URL is the registered URL.
	URL string
	// NextRefresh is when the next refresh was scheduled after the last
	// successful refresh, in local time; zero until a refresh has succeeded.
	NextRefresh time.Time
	// LastRefresh is the time of the last successful refresh, in local time;
	// zero until a refresh has succeeded.
	LastRefresh time.Time
}
639
640 func (af *AutoRefresh) Snapshot() <-chan TargetSnapshot {
641 af.muRegistry.Lock()
642 ch := make(chan TargetSnapshot, len(af.registry))
643 for url, t := range af.registry {
644 ch <- TargetSnapshot{
645 URL: url,
646 NextRefresh: t.nextRefresh,
647 LastRefresh: t.lastRefresh,
648 }
649 }
650 af.muRegistry.Unlock()
651 close(ch)
652 return ch
653 }
654
View as plain text