1 package jwk_test
2
3 import (
4 "bytes"
5 "context"
6 "fmt"
7 "net/http"
8 "net/http/httptest"
9 "sync"
10 "testing"
11 "time"
12
13 "github.com/lestrrat-go/backoff/v2"
14 "github.com/lestrrat-go/iter/arrayiter"
15 "github.com/lestrrat-go/jwx/internal/json"
16 "github.com/lestrrat-go/jwx/internal/jwxtest"
17 "github.com/lestrrat-go/jwx/jwk"
18 "github.com/stretchr/testify/assert"
19 "github.com/stretchr/testify/require"
20 )
21
22
23 func checkAccessCount(t *testing.T, ctx context.Context, src arrayiter.Source, expected ...int) bool {
24 t.Helper()
25
26 iter := src.Iterate(ctx)
27 iter.Next(ctx)
28
29 key := iter.Pair().Value.(jwk.Key)
30 v, ok := key.Get(`accessCount`)
31 if !assert.True(t, ok, `key.Get("accessCount") should succeed`) {
32 return false
33 }
34
35 for _, e := range expected {
36 if v == float64(e) {
37 return assert.Equal(t, float64(e), v, `key.Get("accessCount") should be %d`, e)
38 }
39 }
40
41 var buf bytes.Buffer
42 fmt.Fprint(&buf, "[")
43 for i, e := range expected {
44 fmt.Fprintf(&buf, "%d", e)
45 if i < len(expected)-1 {
46 fmt.Fprint(&buf, ", ")
47 }
48 }
49 fmt.Fprintf(&buf, "]")
50 return assert.Failf(t, `key.Get("accessCount") should be one of %s (got %d)`, buf.String(), v)
51 }
52
53 func TestAutoRefresh(t *testing.T) {
54 t.Parallel()
55
56 t.Run("Specify explicit refresh interval", func(t *testing.T) {
57 t.Parallel()
58 ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
59 defer cancel()
60
61 var accessCount int
62 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
63 accessCount++
64
65 key := map[string]interface{}{
66 "kty": "EC",
67 "crv": "P-256",
68 "x": "SVqB4JcUD6lsfvqMr-OKUNUphdNn64Eay60978ZlL74",
69 "y": "lf0u0pMj4lGAzZix5u4Cm5CMQIgMNpkwy163wtKYVKI",
70 "accessCount": accessCount,
71 }
72 hdrs := w.Header()
73 hdrs.Set(`Content-Type`, `application/json`)
74 hdrs.Set(`Cache-Control`, `max-age=7200`)
75
76 json.NewEncoder(w).Encode(key)
77 }))
78 defer srv.Close()
79
80 af := jwk.NewAutoRefresh(ctx)
81 af.Configure(srv.URL, jwk.WithRefreshInterval(3*time.Second))
82
83 retries := 5
84
85 var wg sync.WaitGroup
86 wg.Add(retries)
87 for i := 0; i < retries; i++ {
88
89 go func() {
90 defer wg.Done()
91 ks, err := af.Fetch(ctx, srv.URL)
92 if !assert.NoError(t, err, `af.Fetch should succeed`) {
93 return
94 }
95 if !checkAccessCount(t, ctx, ks, 1) {
96 return
97 }
98 }()
99 }
100
101 t.Logf("Waiting for fetching goroutines...")
102 wg.Wait()
103 t.Logf("Waiting for the refresh ...")
104 time.Sleep(4 * time.Second)
105 ks, err := af.Fetch(ctx, srv.URL)
106 if !assert.NoError(t, err, `af.Fetch should succeed`) {
107 return
108 }
109 if !checkAccessCount(t, ctx, ks, 2) {
110 return
111 }
112 })
113 t.Run("Calculate next refresh from Cache-Control header", func(t *testing.T) {
114 t.Parallel()
115 ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
116 defer cancel()
117
118 var accessCount int
119 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
120 accessCount++
121
122 key := map[string]interface{}{
123 "kty": "EC",
124 "crv": "P-256",
125 "x": "SVqB4JcUD6lsfvqMr-OKUNUphdNn64Eay60978ZlL74",
126 "y": "lf0u0pMj4lGAzZix5u4Cm5CMQIgMNpkwy163wtKYVKI",
127 "accessCount": accessCount,
128 }
129 hdrs := w.Header()
130 hdrs.Set(`Content-Type`, `application/json`)
131 hdrs.Set(`Cache-Control`, `max-age=3`)
132
133 json.NewEncoder(w).Encode(key)
134 }))
135 defer srv.Close()
136
137 af := jwk.NewAutoRefresh(ctx)
138 af.Configure(srv.URL, jwk.WithMinRefreshInterval(time.Second))
139 if !assert.True(t, af.IsRegistered(srv.URL), `af.IsRegistered should be true`) {
140 return
141 }
142
143 retries := 5
144
145 var wg sync.WaitGroup
146 wg.Add(retries)
147 for i := 0; i < retries; i++ {
148
149 go func() {
150 defer wg.Done()
151 ks, err := af.Fetch(ctx, srv.URL)
152 if !assert.NoError(t, err, `af.Fetch should succeed`) {
153 return
154 }
155
156 if !checkAccessCount(t, ctx, ks, 1) {
157 return
158 }
159 }()
160 }
161
162 t.Logf("Waiting for fetching goroutines...")
163 wg.Wait()
164 t.Logf("Waiting for the refresh ...")
165 time.Sleep(4 * time.Second)
166 ks, err := af.Fetch(ctx, srv.URL)
167 if !assert.NoError(t, err, `af.Fetch should succeed`) {
168 return
169 }
170 if !checkAccessCount(t, ctx, ks, 2) {
171 return
172 }
173 })
174 t.Run("Backoff", func(t *testing.T) {
175 t.Parallel()
176 ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
177 defer cancel()
178
179 var accessCount int
180 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
181 accessCount++
182 if accessCount > 1 && accessCount < 4 {
183 http.Error(w, "wait for it....", http.StatusForbidden)
184 return
185 }
186
187 key := map[string]interface{}{
188 "kty": "EC",
189 "crv": "P-256",
190 "x": "SVqB4JcUD6lsfvqMr-OKUNUphdNn64Eay60978ZlL74",
191 "y": "lf0u0pMj4lGAzZix5u4Cm5CMQIgMNpkwy163wtKYVKI",
192 "accessCount": accessCount,
193 }
194 hdrs := w.Header()
195 hdrs.Set(`Content-Type`, `application/json`)
196 hdrs.Set(`Cache-Control`, `max-age=1`)
197
198 json.NewEncoder(w).Encode(key)
199 }))
200 defer srv.Close()
201
202 af := jwk.NewAutoRefresh(ctx)
203 bo := backoff.Constant(backoff.WithInterval(time.Second))
204 af.Configure(srv.URL, jwk.WithFetchBackoff(bo), jwk.WithMinRefreshInterval(1))
205
206
207 ks, err := af.Fetch(ctx, srv.URL)
208 if !assert.NoError(t, err, `af.Fetch (#1) should succed`) {
209 return
210 }
211 if !checkAccessCount(t, ctx, ks, 1) {
212 return
213 }
214
215
216 time.Sleep(1500 * time.Millisecond)
217 ks, err = af.Fetch(ctx, srv.URL)
218 if !assert.NoError(t, err, `af.Fetch (#2) should succeed`) {
219 return
220 }
221
222 if !checkAccessCount(t, ctx, ks, 1) {
223 return
224 }
225
226
227 time.Sleep(2500 * time.Millisecond)
228
229 ks, err = af.Fetch(ctx, srv.URL)
230 if !assert.NoError(t, err, `af.Fetch (#3) should succeed`) {
231 return
232 }
233
234 if !checkAccessCount(t, ctx, ks, 4, 5) {
235 return
236 }
237 })
238 }
239
240 func TestRefreshSnapshot(t *testing.T) {
241 ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
242 defer cancel()
243
244 var jwksURLs []string
245 getJwksURL := func(dst *[]string, url string) bool {
246 req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
247 if err != nil {
248 return false
249 }
250
251 res, err := http.DefaultClient.Do(req)
252 if err != nil {
253 return false
254 }
255 defer res.Body.Close()
256
257 var m map[string]interface{}
258 if err := json.NewDecoder(res.Body).Decode(&m); err != nil {
259 return false
260 }
261
262 jwksURL, ok := m["jwks_uri"]
263 if !ok {
264 return false
265 }
266 *dst = append(*dst, jwksURL.(string))
267 return true
268 }
269 if !getJwksURL(&jwksURLs, "https://oidc-sample.onelogin.com/oidc/2/.well-known/openid-configuration") {
270 t.SkipNow()
271 }
272 if !getJwksURL(&jwksURLs, "https://accounts.google.com/.well-known/openid-configuration") {
273 t.SkipNow()
274 }
275
276 ar := jwk.NewAutoRefresh(ctx)
277 for _, url := range jwksURLs {
278 ar.Configure(url)
279 }
280
281 for _, url := range jwksURLs {
282 _, _ = ar.Refresh(ctx, url)
283 }
284
285 for target := range ar.Snapshot() {
286 t.Logf("%s last refreshed at %s, next refresh at %s", target.URL, target.LastRefresh, target.NextRefresh)
287 }
288
289 for _, url := range jwksURLs {
290 ar.Remove(url)
291 }
292
293 if !assert.Len(t, ar.Snapshot(), 0, `there should be no URLs`) {
294 return
295 }
296
297 if !assert.Error(t, ar.Remove(`dummy`), `removing a non-existing url should be an error`) {
298 return
299 }
300 }
301
302 func TestErrorSink(t *testing.T) {
303 t.Parallel()
304
305 k, err := jwxtest.GenerateRsaJwk()
306 if !assert.NoError(t, err, `jwxtest.GenerateRsaJwk should succeed`) {
307 return
308 }
309 set := jwk.NewSet()
310 set.Add(k)
311 testcases := []struct {
312 Name string
313 Options func() []jwk.AutoRefreshOption
314 Handler http.Handler
315 }{
316 {
317 Name: "non-200 response",
318 Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
319 w.WriteHeader(http.StatusForbidden)
320 }),
321 },
322 {
323 Name: "invalid JWK",
324 Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
325 w.WriteHeader(http.StatusOK)
326 w.Write([]byte(`{"empty": "nonthingness"}`))
327 }),
328 },
329 {
330 Name: `rejected by whitelist`,
331 Options: func() []jwk.AutoRefreshOption {
332 return []jwk.AutoRefreshOption{
333 jwk.WithFetchWhitelist(jwk.WhitelistFunc(func(_ string) bool {
334 return false
335 })),
336 }
337 },
338 Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
339 w.WriteHeader(http.StatusOK)
340 json.NewEncoder(w).Encode(k)
341 }),
342 },
343 }
344
345 for _, tc := range testcases {
346 tc := tc
347 t.Run(tc.Name, func(t *testing.T) {
348 t.Parallel()
349 ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
350 defer cancel()
351 srv := httptest.NewServer(tc.Handler)
352 defer srv.Close()
353
354 ar := jwk.NewAutoRefresh(ctx)
355
356 var options []jwk.AutoRefreshOption
357 if f := tc.Options; f != nil {
358 options = f()
359 }
360 options = append(options, jwk.WithRefreshInterval(500*time.Millisecond))
361 ar.Configure(srv.URL, options...)
362 ch := make(chan jwk.AutoRefreshError, 256)
363 ar.ErrorSink(ch)
364 ar.Fetch(ctx, srv.URL)
365
366 timer := time.NewTimer(3 * time.Second)
367
368 select {
369 case <-ctx.Done():
370 t.Errorf(`ctx.Done before timer`)
371 case <-timer.C:
372 }
373
374 cancel()
375
376
377
378 l := len(ch)
379 if !assert.True(t, l <= 7, "number of errors shold be less than or equal to 7 (%d)", l) {
380 return
381 }
382 if !assert.True(t, l >= 5, "number of errors shold be greather than or equal to 5 (%d)", l) {
383 return
384 }
385 })
386 }
387 }
388
389 func TestAutoRefreshRace(t *testing.T) {
390 k, err := jwxtest.GenerateRsaJwk()
391 if !assert.NoError(t, err, `jwxtest.GenerateRsaJwk should succeed`) {
392 return
393 }
394 set := jwk.NewSet()
395 set.Add(k)
396
397
398 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
399 w.WriteHeader(http.StatusOK)
400 json.NewEncoder(w).Encode(k)
401 }))
402 defer srv.Close()
403
404
405 ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
406 defer cancel()
407 ar := jwk.NewAutoRefresh(ctx)
408 ch := make(chan jwk.AutoRefreshError, 256)
409 ar.ErrorSink(ch)
410
411 wg := sync.WaitGroup{}
412 routineErr := make(chan error, 20)
413
414
415
416
417 for i := 0; i < 5000; i++ {
418 wg.Add(1)
419 go func() {
420 defer wg.Done()
421 ctx := context.Background()
422
423 ar.Configure(srv.URL, jwk.WithRefreshInterval(500*time.Millisecond))
424 _, err := ar.Refresh(ctx, srv.URL)
425
426 if err != nil {
427 routineErr <- err
428 }
429 }()
430 }
431 wg.Wait()
432
433 require.Len(t, routineErr, 0)
434 }
435
View as plain text