1 package bearer
2
3 import (
4 "context"
5 "fmt"
6 "strconv"
7 "strings"
8 "sync"
9 "sync/atomic"
10 "testing"
11 "time"
12 )
13
14 var _ TokenProvider = (*TokenCache)(nil)
15
16 func TestTokenCache_cache(t *testing.T) {
17 expectToken := Token{
18 Value: "abc123",
19 }
20
21 var retrieveCalled bool
22 provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
23 if retrieveCalled {
24 t.Fatalf("expect wrapped provider to be called once")
25 }
26 retrieveCalled = true
27 return expectToken, nil
28 }))
29
30 token, err := provider.RetrieveBearerToken(context.Background())
31 if err != nil {
32 t.Fatalf("expect no error, got %v", err)
33 }
34 if expectToken != token {
35 t.Errorf("expect token match: %v != %v", expectToken, token)
36 }
37
38 for i := 0; i < 100; i++ {
39 token, err := provider.RetrieveBearerToken(context.Background())
40 if err != nil {
41 t.Fatalf("expect no error, got %v", err)
42 }
43 if expectToken != token {
44 t.Errorf("expect token match: %v != %v", expectToken, token)
45 }
46 }
47 }
48
49 func TestTokenCache_cacheConcurrent(t *testing.T) {
50 expectToken := Token{
51 Value: "abc123",
52 }
53
54 var retrieveCalled bool
55 provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
56 if retrieveCalled {
57 t.Fatalf("expect wrapped provider to be called once")
58 }
59 retrieveCalled = true
60 return expectToken, nil
61 }))
62
63 token, err := provider.RetrieveBearerToken(context.Background())
64 if err != nil {
65 t.Fatalf("expect no error, got %v", err)
66 }
67 if expectToken != token {
68 t.Errorf("expect token match: %v != %v", expectToken, token)
69 }
70
71 for i := 0; i < 100; i++ {
72 t.Run(strconv.Itoa(i), func(t *testing.T) {
73 t.Parallel()
74
75 token, err := provider.RetrieveBearerToken(context.Background())
76 if err != nil {
77 t.Fatalf("expect no error, got %v", err)
78 }
79 if expectToken != token {
80 t.Errorf("expect token match: %v != %v", expectToken, token)
81 }
82 })
83 }
84 }
85
86 func TestTokenCache_expired(t *testing.T) {
87 origTimeNow := timeNow
88 defer func() { timeNow = origTimeNow }()
89
90 timeNow = func() time.Time { return time.Time{} }
91
92 expectToken := Token{
93 Value: "abc123",
94 CanExpire: true,
95 Expires: timeNow().Add(10 * time.Minute),
96 }
97 refreshedToken := Token{
98 Value: "refreshed-abc123",
99 CanExpire: true,
100 Expires: timeNow().Add(30 * time.Minute),
101 }
102
103 retrievedCount := new(int32)
104 provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
105 if atomic.AddInt32(retrievedCount, 1) > 1 {
106 return refreshedToken, nil
107 }
108 return expectToken, nil
109 }))
110
111 for i := 0; i < 10; i++ {
112 token, err := provider.RetrieveBearerToken(context.Background())
113 if err != nil {
114 t.Fatalf("expect no error, got %v", err)
115 }
116 if expectToken != token {
117 t.Errorf("expect token match: %v != %v", expectToken, token)
118 }
119 }
120 if e, a := 1, int(atomic.LoadInt32(retrievedCount)); e != a {
121 t.Errorf("expect %v provider calls, got %v", e, a)
122 }
123
124
125 timeNow = func() time.Time {
126 return (time.Time{}).Add(10 * time.Minute)
127 }
128
129 token, err := provider.RetrieveBearerToken(context.Background())
130 if err != nil {
131 t.Fatalf("expect no error, got %v", err)
132 }
133 if refreshedToken != token {
134 t.Errorf("expect refreshed token match: %v != %v", refreshedToken, token)
135 }
136 if e, a := 2, int(atomic.LoadInt32(retrievedCount)); e != a {
137 t.Errorf("expect %v provider calls, got %v", e, a)
138 }
139 }
140
141 func TestTokenCache_cancelled(t *testing.T) {
142 providerRunning := make(chan struct{})
143 providerDone := make(chan struct{})
144 var onceClose sync.Once
145 provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
146 onceClose.Do(func() { close(providerRunning) })
147
148
149
150
151 select {
152 case <-providerDone:
153 return Token{Value: "abc123"}, nil
154 case <-ctx.Done():
155 return Token{}, fmt.Errorf("unexpected context canceled, %w", ctx.Err())
156 }
157 }))
158
159 ctx, cancel := context.WithCancel(context.Background())
160 cancel()
161
162
163
164 var wg sync.WaitGroup
165 wg.Add(1)
166 go func() {
167 defer wg.Done()
168
169 _, err := provider.RetrieveBearerToken(ctx)
170 if err == nil {
171 t.Errorf("expect error, got none")
172
173 } else if e, a := "unexpected context canceled", err.Error(); strings.Contains(a, e) {
174 t.Errorf("unexpected context canceled received, %v", err)
175
176 } else if e, a := "context canceled", err.Error(); !strings.Contains(a, e) {
177 t.Errorf("expect %v error in, %v", e, a)
178 }
179 }()
180
181 <-providerRunning
182
183
184
185 wg.Add(1)
186 go func() {
187 defer wg.Done()
188
189 token, err := provider.RetrieveBearerToken(context.Background())
190 if err != nil {
191 t.Errorf("expect no error, got %v", err)
192 } else {
193 expect := Token{Value: "abc123"}
194 if expect != token {
195 t.Errorf("expect token retrieve match: %v != %v", expect, token)
196 }
197 }
198 }()
199 close(providerDone)
200
201 wg.Wait()
202 }
203
204 func TestTokenCache_cancelledWithTimeout(t *testing.T) {
205 providerReady := make(chan struct{})
206 var providerReadCloseOnce sync.Once
207 provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
208 providerReadCloseOnce.Do(func() { close(providerReady) })
209
210 <-ctx.Done()
211 return Token{}, fmt.Errorf("token retrieve timeout, %w", ctx.Err())
212 }), func(o *TokenCacheOptions) {
213 o.RetrieveBearerTokenTimeout = time.Millisecond
214 })
215
216 var wg sync.WaitGroup
217
218
219
220 for i := 0; i < 5; i++ {
221 wg.Add(1)
222 go func() {
223 defer wg.Done()
224 <-providerReady
225
226 _, err := provider.RetrieveBearerToken(context.Background())
227 if err == nil {
228 t.Errorf("expect error, got none")
229
230 } else if e, a := "token retrieve timeout", err.Error(); !strings.Contains(a, e) {
231 t.Errorf("expect %v error in, %v", e, a)
232 }
233 }()
234 }
235
236 _, err := provider.RetrieveBearerToken(context.Background())
237 if err == nil {
238 t.Errorf("expect error, got none")
239
240 } else if e, a := "token retrieve timeout", err.Error(); !strings.Contains(a, e) {
241 t.Errorf("expect %v error in, %v", e, a)
242 }
243
244 wg.Wait()
245 }
246
247 func TestTokenCache_asyncRefresh(t *testing.T) {
248 origTimeNow := timeNow
249 defer func() { timeNow = origTimeNow }()
250
251 timeNow = func() time.Time { return time.Time{} }
252
253 expectToken := Token{
254 Value: "abc123",
255 CanExpire: true,
256 Expires: timeNow().Add(10 * time.Minute),
257 }
258 refreshedToken := Token{
259 Value: "refreshed-abc123",
260 CanExpire: true,
261 Expires: timeNow().Add(30 * time.Minute),
262 }
263
264 retrievedCount := new(int32)
265 provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
266 c := atomic.AddInt32(retrievedCount, 1)
267 switch {
268 case c == 1:
269 return expectToken, nil
270 case c > 1 && c < 5:
271 return Token{}, fmt.Errorf("some error")
272 case c == 5:
273 return refreshedToken, nil
274 default:
275 return Token{}, fmt.Errorf("unexpected error")
276 }
277 }), func(o *TokenCacheOptions) {
278 o.RefreshBeforeExpires = 5 * time.Minute
279 })
280
281
282 token, err := provider.RetrieveBearerToken(context.Background())
283 if err != nil {
284 t.Fatalf("expect no error, got %v", err)
285 }
286 if expectToken != token {
287 t.Errorf("expect token match: %v != %v", expectToken, token)
288 }
289
290
291
292 timeNow = func() time.Time {
293 return (time.Time{}).Add(6 * time.Minute)
294 }
295
296 for i := 0; i < 4; i++ {
297 token, err := provider.RetrieveBearerToken(context.Background())
298 if err != nil {
299 t.Fatalf("expect no error, got %v", err)
300 }
301 if expectToken != token {
302 t.Errorf("expect token match: %v != %v", expectToken, token)
303 }
304 }
305
306 testWaitAsyncRefreshDone(provider)
307
308 if c := int(atomic.LoadInt32(retrievedCount)); c < 2 || c > 5 {
309 t.Fatalf("expect async refresh to be called [2,5) times, got, %v", c)
310 }
311
312
313 if c := atomic.LoadInt32(retrievedCount); c != 5 {
314 atomic.StoreInt32(retrievedCount, 4)
315 token, err := provider.RetrieveBearerToken(context.Background())
316 if err != nil {
317 t.Fatalf("expect no error, got %v", err)
318 }
319 if expectToken != token {
320 t.Errorf("expect token match: %v != %v", expectToken, token)
321 }
322 testWaitAsyncRefreshDone(provider)
323 }
324
325
326
327 token, err = provider.RetrieveBearerToken(context.Background())
328 if err != nil {
329 t.Fatalf("expect no error, got %v", err)
330 }
331 if refreshedToken != token {
332 t.Errorf("expect refreshed token match: %v != %v", refreshedToken, token)
333 }
334 }
335
336 func TestTokenCache_asyncRefreshWithMinDelay(t *testing.T) {
337 origTimeNow := timeNow
338 defer func() { timeNow = origTimeNow }()
339
340 timeNow = func() time.Time { return time.Time{} }
341
342 expectToken := Token{
343 Value: "abc123",
344 CanExpire: true,
345 Expires: timeNow().Add(10 * time.Minute),
346 }
347 refreshedToken := Token{
348 Value: "refreshed-abc123",
349 CanExpire: true,
350 Expires: timeNow().Add(30 * time.Minute),
351 }
352
353 retrievedCount := new(int32)
354 provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
355 c := atomic.AddInt32(retrievedCount, 1)
356 switch {
357 case c == 1:
358 return expectToken, nil
359 case c > 1 && c < 5:
360 return Token{}, fmt.Errorf("some error")
361 case c == 5:
362 return refreshedToken, nil
363 default:
364 return Token{}, fmt.Errorf("unexpected error")
365 }
366 }), func(o *TokenCacheOptions) {
367 o.RefreshBeforeExpires = 5 * time.Minute
368 o.AsyncRefreshMinimumDelay = 30 * time.Second
369 })
370
371
372 token, err := provider.RetrieveBearerToken(context.Background())
373 if err != nil {
374 t.Fatalf("expect no error, got %v", err)
375 }
376 if expectToken != token {
377 t.Errorf("expect token match: %v != %v", expectToken, token)
378 }
379
380
381
382 timeNow = func() time.Time {
383 return (time.Time{}).Add(6 * time.Minute)
384 }
385
386 for i := 0; i < 4; i++ {
387 token, err := provider.RetrieveBearerToken(context.Background())
388 if err != nil {
389 t.Fatalf("expect no error, got %v", err)
390 }
391 if expectToken != token {
392 t.Errorf("expect token match: %v != %v", expectToken, token)
393 }
394
395 testWaitAsyncRefreshDone(provider)
396 }
397
398
399 if e, a := 2, int(atomic.LoadInt32(retrievedCount)); e != a {
400 t.Fatalf("expect %v min async refresh, got %v", e, a)
401 }
402
403
404 timeNow = func() time.Time { return (time.Time{}).Add(7 * time.Minute) }
405
406 atomic.StoreInt32(retrievedCount, 4)
407
408
409 token, err = provider.RetrieveBearerToken(context.Background())
410 if err != nil {
411 t.Fatalf("expect no error, got %v", err)
412 }
413 if expectToken != token {
414 t.Errorf("expect token match: %v != %v", expectToken, token)
415 }
416
417 testWaitAsyncRefreshDone(provider)
418
419
420
421 token, err = provider.RetrieveBearerToken(context.Background())
422 if err != nil {
423 t.Fatalf("expect no error, got %v", err)
424 }
425 if refreshedToken != token {
426 t.Errorf("expect refreshed token match: %v != %v", refreshedToken, token)
427 }
428 }
429
430 func TestTokenCache_disableAsyncRefresh(t *testing.T) {
431 origTimeNow := timeNow
432 defer func() { timeNow = origTimeNow }()
433
434 timeNow = func() time.Time { return time.Time{} }
435
436 expectToken := Token{
437 Value: "abc123",
438 CanExpire: true,
439 Expires: timeNow().Add(10 * time.Minute),
440 }
441 refreshedToken := Token{
442 Value: "refreshed-abc123",
443 CanExpire: true,
444 Expires: timeNow().Add(30 * time.Minute),
445 }
446
447 retrievedCount := new(int32)
448 provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
449 c := atomic.AddInt32(retrievedCount, 1)
450 switch {
451 case c == 1:
452 return expectToken, nil
453 case c > 1 && c < 5:
454 return Token{}, fmt.Errorf("some error")
455 case c == 5:
456 return refreshedToken, nil
457 default:
458 return Token{}, fmt.Errorf("unexpected error")
459 }
460 }), func(o *TokenCacheOptions) {
461 o.RefreshBeforeExpires = 5 * time.Minute
462 o.DisableAsyncRefresh = true
463 })
464
465
466 token, err := provider.RetrieveBearerToken(context.Background())
467 if err != nil {
468 t.Fatalf("expect no error, got %v", err)
469 }
470 if expectToken != token {
471 t.Errorf("expect token match: %v != %v", expectToken, token)
472 }
473
474
475 timeNow = func() time.Time {
476 return (time.Time{}).Add(6 * time.Minute)
477 }
478
479 for i := 0; i < 3; i++ {
480 _, err = provider.RetrieveBearerToken(context.Background())
481 if err == nil {
482 t.Fatalf("expect error, got none")
483 }
484 if e, a := "some error", err.Error(); !strings.Contains(a, e) {
485 t.Fatalf("expect %v error in %v", e, a)
486 }
487 if e, a := i+2, int(atomic.LoadInt32(retrievedCount)); e != a {
488 t.Fatalf("expect %v retrieveCount, got %v", e, a)
489 }
490 }
491 if e, a := 4, int(atomic.LoadInt32(retrievedCount)); e != a {
492 t.Fatalf("expect %v retrieveCount, got %v", e, a)
493 }
494
495
496
497 token, err = provider.RetrieveBearerToken(context.Background())
498 if err != nil {
499 t.Fatalf("expect no error, got %v", err)
500 }
501 if refreshedToken != token {
502 t.Errorf("expect refreshed token match: %v != %v", refreshedToken, token)
503 }
504 }
505
506 func testWaitAsyncRefreshDone(provider *TokenCache) {
507 asyncResCh := provider.sfGroup.DoChan("async-refresh", func() (interface{}, error) {
508 return nil, nil
509 })
510 <-asyncResCh
511 }
512
View as plain text