...

Source file src/github.com/aws/smithy-go/auth/bearer/token_cache_test.go

Documentation: github.com/aws/smithy-go/auth/bearer

     1  package bearer
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"strconv"
     7  	"strings"
     8  	"sync"
     9  	"sync/atomic"
    10  	"testing"
    11  	"time"
    12  )
    13  
    14  var _ TokenProvider = (*TokenCache)(nil)
    15  
    16  func TestTokenCache_cache(t *testing.T) {
    17  	expectToken := Token{
    18  		Value: "abc123",
    19  	}
    20  
    21  	var retrieveCalled bool
    22  	provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
    23  		if retrieveCalled {
    24  			t.Fatalf("expect wrapped provider to be called once")
    25  		}
    26  		retrieveCalled = true
    27  		return expectToken, nil
    28  	}))
    29  
    30  	token, err := provider.RetrieveBearerToken(context.Background())
    31  	if err != nil {
    32  		t.Fatalf("expect no error, got %v", err)
    33  	}
    34  	if expectToken != token {
    35  		t.Errorf("expect token match: %v != %v", expectToken, token)
    36  	}
    37  
    38  	for i := 0; i < 100; i++ {
    39  		token, err := provider.RetrieveBearerToken(context.Background())
    40  		if err != nil {
    41  			t.Fatalf("expect no error, got %v", err)
    42  		}
    43  		if expectToken != token {
    44  			t.Errorf("expect token match: %v != %v", expectToken, token)
    45  		}
    46  	}
    47  }
    48  
    49  func TestTokenCache_cacheConcurrent(t *testing.T) {
    50  	expectToken := Token{
    51  		Value: "abc123",
    52  	}
    53  
    54  	var retrieveCalled bool
    55  	provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
    56  		if retrieveCalled {
    57  			t.Fatalf("expect wrapped provider to be called once")
    58  		}
    59  		retrieveCalled = true
    60  		return expectToken, nil
    61  	}))
    62  
    63  	token, err := provider.RetrieveBearerToken(context.Background())
    64  	if err != nil {
    65  		t.Fatalf("expect no error, got %v", err)
    66  	}
    67  	if expectToken != token {
    68  		t.Errorf("expect token match: %v != %v", expectToken, token)
    69  	}
    70  
    71  	for i := 0; i < 100; i++ {
    72  		t.Run(strconv.Itoa(i), func(t *testing.T) {
    73  			t.Parallel()
    74  
    75  			token, err := provider.RetrieveBearerToken(context.Background())
    76  			if err != nil {
    77  				t.Fatalf("expect no error, got %v", err)
    78  			}
    79  			if expectToken != token {
    80  				t.Errorf("expect token match: %v != %v", expectToken, token)
    81  			}
    82  		})
    83  	}
    84  }
    85  
    86  func TestTokenCache_expired(t *testing.T) {
    87  	origTimeNow := timeNow
    88  	defer func() { timeNow = origTimeNow }()
    89  
    90  	timeNow = func() time.Time { return time.Time{} }
    91  
    92  	expectToken := Token{
    93  		Value:     "abc123",
    94  		CanExpire: true,
    95  		Expires:   timeNow().Add(10 * time.Minute),
    96  	}
    97  	refreshedToken := Token{
    98  		Value:     "refreshed-abc123",
    99  		CanExpire: true,
   100  		Expires:   timeNow().Add(30 * time.Minute),
   101  	}
   102  
   103  	retrievedCount := new(int32)
   104  	provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
   105  		if atomic.AddInt32(retrievedCount, 1) > 1 {
   106  			return refreshedToken, nil
   107  		}
   108  		return expectToken, nil
   109  	}))
   110  
   111  	for i := 0; i < 10; i++ {
   112  		token, err := provider.RetrieveBearerToken(context.Background())
   113  		if err != nil {
   114  			t.Fatalf("expect no error, got %v", err)
   115  		}
   116  		if expectToken != token {
   117  			t.Errorf("expect token match: %v != %v", expectToken, token)
   118  		}
   119  	}
   120  	if e, a := 1, int(atomic.LoadInt32(retrievedCount)); e != a {
   121  		t.Errorf("expect %v provider calls, got %v", e, a)
   122  	}
   123  
   124  	// Offset time for refresh
   125  	timeNow = func() time.Time {
   126  		return (time.Time{}).Add(10 * time.Minute)
   127  	}
   128  
   129  	token, err := provider.RetrieveBearerToken(context.Background())
   130  	if err != nil {
   131  		t.Fatalf("expect no error, got %v", err)
   132  	}
   133  	if refreshedToken != token {
   134  		t.Errorf("expect refreshed token match: %v != %v", refreshedToken, token)
   135  	}
   136  	if e, a := 2, int(atomic.LoadInt32(retrievedCount)); e != a {
   137  		t.Errorf("expect %v provider calls, got %v", e, a)
   138  	}
   139  }
   140  
   141  func TestTokenCache_cancelled(t *testing.T) {
   142  	providerRunning := make(chan struct{})
   143  	providerDone := make(chan struct{})
   144  	var onceClose sync.Once
   145  	provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
   146  		onceClose.Do(func() { close(providerRunning) })
   147  
   148  		// Provider running never receives context cancel so that if the first
   149  		// retrieve call is canceled all subsequent retrieve callers won't get
   150  		// canceled as well.
   151  		select {
   152  		case <-providerDone:
   153  			return Token{Value: "abc123"}, nil
   154  		case <-ctx.Done():
   155  			return Token{}, fmt.Errorf("unexpected context canceled, %w", ctx.Err())
   156  		}
   157  	}))
   158  
   159  	ctx, cancel := context.WithCancel(context.Background())
   160  	cancel()
   161  
   162  	// Retrieve that will have its context canceled, should return error, but
   163  	// underlying provider retrieve will continue to block in the background.
   164  	var wg sync.WaitGroup
   165  	wg.Add(1)
   166  	go func() {
   167  		defer wg.Done()
   168  
   169  		_, err := provider.RetrieveBearerToken(ctx)
   170  		if err == nil {
   171  			t.Errorf("expect error, got none")
   172  
   173  		} else if e, a := "unexpected context canceled", err.Error(); strings.Contains(a, e) {
   174  			t.Errorf("unexpected context canceled received, %v", err)
   175  
   176  		} else if e, a := "context canceled", err.Error(); !strings.Contains(a, e) {
   177  			t.Errorf("expect %v error in, %v", e, a)
   178  		}
   179  	}()
   180  
   181  	<-providerRunning
   182  
   183  	// Retrieve that will be added to existing single flight group, (or create
   184  	// a new group). Returning valid token.
   185  	wg.Add(1)
   186  	go func() {
   187  		defer wg.Done()
   188  
   189  		token, err := provider.RetrieveBearerToken(context.Background())
   190  		if err != nil {
   191  			t.Errorf("expect no error, got %v", err)
   192  		} else {
   193  			expect := Token{Value: "abc123"}
   194  			if expect != token {
   195  				t.Errorf("expect token retrieve match: %v != %v", expect, token)
   196  			}
   197  		}
   198  	}()
   199  	close(providerDone)
   200  
   201  	wg.Wait()
   202  }
   203  
   204  func TestTokenCache_cancelledWithTimeout(t *testing.T) {
   205  	providerReady := make(chan struct{})
   206  	var providerReadCloseOnce sync.Once
   207  	provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
   208  		providerReadCloseOnce.Do(func() { close(providerReady) })
   209  
   210  		<-ctx.Done()
   211  		return Token{}, fmt.Errorf("token retrieve timeout, %w", ctx.Err())
   212  	}), func(o *TokenCacheOptions) {
   213  		o.RetrieveBearerTokenTimeout = time.Millisecond
   214  	})
   215  
   216  	var wg sync.WaitGroup
   217  
   218  	// Spin up additional retrieves that will be deduplicated and block on the
   219  	// original retrieve call.
   220  	for i := 0; i < 5; i++ {
   221  		wg.Add(1)
   222  		go func() {
   223  			defer wg.Done()
   224  			<-providerReady
   225  
   226  			_, err := provider.RetrieveBearerToken(context.Background())
   227  			if err == nil {
   228  				t.Errorf("expect error, got none")
   229  
   230  			} else if e, a := "token retrieve timeout", err.Error(); !strings.Contains(a, e) {
   231  				t.Errorf("expect %v error in, %v", e, a)
   232  			}
   233  		}()
   234  	}
   235  
   236  	_, err := provider.RetrieveBearerToken(context.Background())
   237  	if err == nil {
   238  		t.Errorf("expect error, got none")
   239  
   240  	} else if e, a := "token retrieve timeout", err.Error(); !strings.Contains(a, e) {
   241  		t.Errorf("expect %v error in, %v", e, a)
   242  	}
   243  
   244  	wg.Wait()
   245  }
   246  
   247  func TestTokenCache_asyncRefresh(t *testing.T) {
   248  	origTimeNow := timeNow
   249  	defer func() { timeNow = origTimeNow }()
   250  
   251  	timeNow = func() time.Time { return time.Time{} }
   252  
   253  	expectToken := Token{
   254  		Value:     "abc123",
   255  		CanExpire: true,
   256  		Expires:   timeNow().Add(10 * time.Minute),
   257  	}
   258  	refreshedToken := Token{
   259  		Value:     "refreshed-abc123",
   260  		CanExpire: true,
   261  		Expires:   timeNow().Add(30 * time.Minute),
   262  	}
   263  
   264  	retrievedCount := new(int32)
   265  	provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
   266  		c := atomic.AddInt32(retrievedCount, 1)
   267  		switch {
   268  		case c == 1:
   269  			return expectToken, nil
   270  		case c > 1 && c < 5:
   271  			return Token{}, fmt.Errorf("some error")
   272  		case c == 5:
   273  			return refreshedToken, nil
   274  		default:
   275  			return Token{}, fmt.Errorf("unexpected error")
   276  		}
   277  	}), func(o *TokenCacheOptions) {
   278  		o.RefreshBeforeExpires = 5 * time.Minute
   279  	})
   280  
   281  	// 1: Initial retrieve to cache token
   282  	token, err := provider.RetrieveBearerToken(context.Background())
   283  	if err != nil {
   284  		t.Fatalf("expect no error, got %v", err)
   285  	}
   286  	if expectToken != token {
   287  		t.Errorf("expect token match: %v != %v", expectToken, token)
   288  	}
   289  
   290  	// 2-5: Offset time for subsequent calls to retrieve to trigger asynchronous
   291  	// refreshes.
   292  	timeNow = func() time.Time {
   293  		return (time.Time{}).Add(6 * time.Minute)
   294  	}
   295  
   296  	for i := 0; i < 4; i++ {
   297  		token, err := provider.RetrieveBearerToken(context.Background())
   298  		if err != nil {
   299  			t.Fatalf("expect no error, got %v", err)
   300  		}
   301  		if expectToken != token {
   302  			t.Errorf("expect token match: %v != %v", expectToken, token)
   303  		}
   304  	}
   305  	// Wait for all async refreshes to complete
   306  	testWaitAsyncRefreshDone(provider)
   307  
   308  	if c := int(atomic.LoadInt32(retrievedCount)); c < 2 || c > 5 {
   309  		t.Fatalf("expect async refresh to be called [2,5) times, got, %v", c)
   310  	}
   311  
   312  	// Ensure enough retrieves have been done to trigger refresh.
   313  	if c := atomic.LoadInt32(retrievedCount); c != 5 {
   314  		atomic.StoreInt32(retrievedCount, 4)
   315  		token, err := provider.RetrieveBearerToken(context.Background())
   316  		if err != nil {
   317  			t.Fatalf("expect no error, got %v", err)
   318  		}
   319  		if expectToken != token {
   320  			t.Errorf("expect token match: %v != %v", expectToken, token)
   321  		}
   322  		testWaitAsyncRefreshDone(provider)
   323  	}
   324  
   325  	// Last async refresh will succeed and update cached token, expect the next
   326  	// call to get refreshed token.
   327  	token, err = provider.RetrieveBearerToken(context.Background())
   328  	if err != nil {
   329  		t.Fatalf("expect no error, got %v", err)
   330  	}
   331  	if refreshedToken != token {
   332  		t.Errorf("expect refreshed token match: %v != %v", refreshedToken, token)
   333  	}
   334  }
   335  
   336  func TestTokenCache_asyncRefreshWithMinDelay(t *testing.T) {
   337  	origTimeNow := timeNow
   338  	defer func() { timeNow = origTimeNow }()
   339  
   340  	timeNow = func() time.Time { return time.Time{} }
   341  
   342  	expectToken := Token{
   343  		Value:     "abc123",
   344  		CanExpire: true,
   345  		Expires:   timeNow().Add(10 * time.Minute),
   346  	}
   347  	refreshedToken := Token{
   348  		Value:     "refreshed-abc123",
   349  		CanExpire: true,
   350  		Expires:   timeNow().Add(30 * time.Minute),
   351  	}
   352  
   353  	retrievedCount := new(int32)
   354  	provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
   355  		c := atomic.AddInt32(retrievedCount, 1)
   356  		switch {
   357  		case c == 1:
   358  			return expectToken, nil
   359  		case c > 1 && c < 5:
   360  			return Token{}, fmt.Errorf("some error")
   361  		case c == 5:
   362  			return refreshedToken, nil
   363  		default:
   364  			return Token{}, fmt.Errorf("unexpected error")
   365  		}
   366  	}), func(o *TokenCacheOptions) {
   367  		o.RefreshBeforeExpires = 5 * time.Minute
   368  		o.AsyncRefreshMinimumDelay = 30 * time.Second
   369  	})
   370  
   371  	// 1: Initial retrieve to cache token
   372  	token, err := provider.RetrieveBearerToken(context.Background())
   373  	if err != nil {
   374  		t.Fatalf("expect no error, got %v", err)
   375  	}
   376  	if expectToken != token {
   377  		t.Errorf("expect token match: %v != %v", expectToken, token)
   378  	}
   379  
   380  	// 2-5: Offset time for subsequent calls to retrieve to trigger asynchronous
   381  	// refreshes.
   382  	timeNow = func() time.Time {
   383  		return (time.Time{}).Add(6 * time.Minute)
   384  	}
   385  
   386  	for i := 0; i < 4; i++ {
   387  		token, err := provider.RetrieveBearerToken(context.Background())
   388  		if err != nil {
   389  			t.Fatalf("expect no error, got %v", err)
   390  		}
   391  		if expectToken != token {
   392  			t.Errorf("expect token match: %v != %v", expectToken, token)
   393  		}
   394  		// Wait for all async refreshes to complete ensure not deduped
   395  		testWaitAsyncRefreshDone(provider)
   396  	}
   397  
   398  	// Only a single refresh attempt is expected.
   399  	if e, a := 2, int(atomic.LoadInt32(retrievedCount)); e != a {
   400  		t.Fatalf("expect %v min async refresh, got %v", e, a)
   401  	}
   402  
   403  	// Move time forward to ensure another async refresh is triggered.
   404  	timeNow = func() time.Time { return (time.Time{}).Add(7 * time.Minute) }
   405  	// Make sure the next attempt refreshes the token
   406  	atomic.StoreInt32(retrievedCount, 4)
   407  
   408  	// Do async retrieve that will succeed refreshing in background.
   409  	token, err = provider.RetrieveBearerToken(context.Background())
   410  	if err != nil {
   411  		t.Fatalf("expect no error, got %v", err)
   412  	}
   413  	if expectToken != token {
   414  		t.Errorf("expect token match: %v != %v", expectToken, token)
   415  	}
   416  	// Wait for all async refreshes to complete ensure not deduped
   417  	testWaitAsyncRefreshDone(provider)
   418  
   419  	// Last async refresh will succeed and update cached token, expect the next
   420  	// call to get refreshed token.
   421  	token, err = provider.RetrieveBearerToken(context.Background())
   422  	if err != nil {
   423  		t.Fatalf("expect no error, got %v", err)
   424  	}
   425  	if refreshedToken != token {
   426  		t.Errorf("expect refreshed token match: %v != %v", refreshedToken, token)
   427  	}
   428  }
   429  
   430  func TestTokenCache_disableAsyncRefresh(t *testing.T) {
   431  	origTimeNow := timeNow
   432  	defer func() { timeNow = origTimeNow }()
   433  
   434  	timeNow = func() time.Time { return time.Time{} }
   435  
   436  	expectToken := Token{
   437  		Value:     "abc123",
   438  		CanExpire: true,
   439  		Expires:   timeNow().Add(10 * time.Minute),
   440  	}
   441  	refreshedToken := Token{
   442  		Value:     "refreshed-abc123",
   443  		CanExpire: true,
   444  		Expires:   timeNow().Add(30 * time.Minute),
   445  	}
   446  
   447  	retrievedCount := new(int32)
   448  	provider := NewTokenCache(TokenProviderFunc(func(ctx context.Context) (Token, error) {
   449  		c := atomic.AddInt32(retrievedCount, 1)
   450  		switch {
   451  		case c == 1:
   452  			return expectToken, nil
   453  		case c > 1 && c < 5:
   454  			return Token{}, fmt.Errorf("some error")
   455  		case c == 5:
   456  			return refreshedToken, nil
   457  		default:
   458  			return Token{}, fmt.Errorf("unexpected error")
   459  		}
   460  	}), func(o *TokenCacheOptions) {
   461  		o.RefreshBeforeExpires = 5 * time.Minute
   462  		o.DisableAsyncRefresh = true
   463  	})
   464  
   465  	// 1: Initial retrieve to cache token
   466  	token, err := provider.RetrieveBearerToken(context.Background())
   467  	if err != nil {
   468  		t.Fatalf("expect no error, got %v", err)
   469  	}
   470  	if expectToken != token {
   471  		t.Errorf("expect token match: %v != %v", expectToken, token)
   472  	}
   473  
   474  	// Update time into refresh window before token expires
   475  	timeNow = func() time.Time {
   476  		return (time.Time{}).Add(6 * time.Minute)
   477  	}
   478  
   479  	for i := 0; i < 3; i++ {
   480  		_, err = provider.RetrieveBearerToken(context.Background())
   481  		if err == nil {
   482  			t.Fatalf("expect error, got none")
   483  		}
   484  		if e, a := "some error", err.Error(); !strings.Contains(a, e) {
   485  			t.Fatalf("expect %v error in %v", e, a)
   486  		}
   487  		if e, a := i+2, int(atomic.LoadInt32(retrievedCount)); e != a {
   488  			t.Fatalf("expect %v retrieveCount, got %v", e, a)
   489  		}
   490  	}
   491  	if e, a := 4, int(atomic.LoadInt32(retrievedCount)); e != a {
   492  		t.Fatalf("expect %v retrieveCount, got %v", e, a)
   493  	}
   494  
   495  	// Last refresh will succeed and update cached token, expect the next
   496  	// call to get refreshed token.
   497  	token, err = provider.RetrieveBearerToken(context.Background())
   498  	if err != nil {
   499  		t.Fatalf("expect no error, got %v", err)
   500  	}
   501  	if refreshedToken != token {
   502  		t.Errorf("expect refreshed token match: %v != %v", refreshedToken, token)
   503  	}
   504  }
   505  
   506  func testWaitAsyncRefreshDone(provider *TokenCache) {
   507  	asyncResCh := provider.sfGroup.DoChan("async-refresh", func() (interface{}, error) {
   508  		return nil, nil
   509  	})
   510  	<-asyncResCh
   511  }
   512  

View as plain text