...

Source file src/github.com/lestrrat-go/jwx/jwk/refresh_test.go

Documentation: github.com/lestrrat-go/jwx/jwk

     1  package jwk_test
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"sync"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/lestrrat-go/backoff/v2"
    14  	"github.com/lestrrat-go/iter/arrayiter"
    15  	"github.com/lestrrat-go/jwx/internal/json"
    16  	"github.com/lestrrat-go/jwx/internal/jwxtest"
    17  	"github.com/lestrrat-go/jwx/jwk"
    18  	"github.com/stretchr/testify/assert"
    19  	"github.com/stretchr/testify/require"
    20  )
    21  
    22  //nolint:revive,golint
    23  func checkAccessCount(t *testing.T, ctx context.Context, src arrayiter.Source, expected ...int) bool {
    24  	t.Helper()
    25  
    26  	iter := src.Iterate(ctx)
    27  	iter.Next(ctx)
    28  
    29  	key := iter.Pair().Value.(jwk.Key)
    30  	v, ok := key.Get(`accessCount`)
    31  	if !assert.True(t, ok, `key.Get("accessCount") should succeed`) {
    32  		return false
    33  	}
    34  
    35  	for _, e := range expected {
    36  		if v == float64(e) {
    37  			return assert.Equal(t, float64(e), v, `key.Get("accessCount") should be %d`, e)
    38  		}
    39  	}
    40  
    41  	var buf bytes.Buffer
    42  	fmt.Fprint(&buf, "[")
    43  	for i, e := range expected {
    44  		fmt.Fprintf(&buf, "%d", e)
    45  		if i < len(expected)-1 {
    46  			fmt.Fprint(&buf, ", ")
    47  		}
    48  	}
    49  	fmt.Fprintf(&buf, "]")
    50  	return assert.Failf(t, `key.Get("accessCount") should be one of %s (got %d)`, buf.String(), v)
    51  }
    52  
    53  func TestAutoRefresh(t *testing.T) {
    54  	t.Parallel()
    55  
    56  	t.Run("Specify explicit refresh interval", func(t *testing.T) {
    57  		t.Parallel()
    58  		ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
    59  		defer cancel()
    60  
    61  		var accessCount int
    62  		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    63  			accessCount++
    64  
    65  			key := map[string]interface{}{
    66  				"kty":         "EC",
    67  				"crv":         "P-256",
    68  				"x":           "SVqB4JcUD6lsfvqMr-OKUNUphdNn64Eay60978ZlL74",
    69  				"y":           "lf0u0pMj4lGAzZix5u4Cm5CMQIgMNpkwy163wtKYVKI",
    70  				"accessCount": accessCount,
    71  			}
    72  			hdrs := w.Header()
    73  			hdrs.Set(`Content-Type`, `application/json`)
    74  			hdrs.Set(`Cache-Control`, `max-age=7200`) // Make sure this is ignored
    75  
    76  			json.NewEncoder(w).Encode(key)
    77  		}))
    78  		defer srv.Close()
    79  
    80  		af := jwk.NewAutoRefresh(ctx)
    81  		af.Configure(srv.URL, jwk.WithRefreshInterval(3*time.Second))
    82  
    83  		retries := 5
    84  
    85  		var wg sync.WaitGroup
    86  		wg.Add(retries)
    87  		for i := 0; i < retries; i++ {
    88  			// Run these in separate goroutines to emulate a possible thundering herd
    89  			go func() {
    90  				defer wg.Done()
    91  				ks, err := af.Fetch(ctx, srv.URL)
    92  				if !assert.NoError(t, err, `af.Fetch should succeed`) {
    93  					return
    94  				}
    95  				if !checkAccessCount(t, ctx, ks, 1) {
    96  					return
    97  				}
    98  			}()
    99  		}
   100  
   101  		t.Logf("Waiting for fetching goroutines...")
   102  		wg.Wait()
   103  		t.Logf("Waiting for the refresh ...")
   104  		time.Sleep(4 * time.Second)
   105  		ks, err := af.Fetch(ctx, srv.URL)
   106  		if !assert.NoError(t, err, `af.Fetch should succeed`) {
   107  			return
   108  		}
   109  		if !checkAccessCount(t, ctx, ks, 2) {
   110  			return
   111  		}
   112  	})
   113  	t.Run("Calculate next refresh from Cache-Control header", func(t *testing.T) {
   114  		t.Parallel()
   115  		ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
   116  		defer cancel()
   117  
   118  		var accessCount int
   119  		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   120  			accessCount++
   121  
   122  			key := map[string]interface{}{
   123  				"kty":         "EC",
   124  				"crv":         "P-256",
   125  				"x":           "SVqB4JcUD6lsfvqMr-OKUNUphdNn64Eay60978ZlL74",
   126  				"y":           "lf0u0pMj4lGAzZix5u4Cm5CMQIgMNpkwy163wtKYVKI",
   127  				"accessCount": accessCount,
   128  			}
   129  			hdrs := w.Header()
   130  			hdrs.Set(`Content-Type`, `application/json`)
   131  			hdrs.Set(`Cache-Control`, `max-age=3`)
   132  
   133  			json.NewEncoder(w).Encode(key)
   134  		}))
   135  		defer srv.Close()
   136  
   137  		af := jwk.NewAutoRefresh(ctx)
   138  		af.Configure(srv.URL, jwk.WithMinRefreshInterval(time.Second))
   139  		if !assert.True(t, af.IsRegistered(srv.URL), `af.IsRegistered should be true`) {
   140  			return
   141  		}
   142  
   143  		retries := 5
   144  
   145  		var wg sync.WaitGroup
   146  		wg.Add(retries)
   147  		for i := 0; i < retries; i++ {
   148  			// Run these in separate goroutines to emulate a possible thundering herd
   149  			go func() {
   150  				defer wg.Done()
   151  				ks, err := af.Fetch(ctx, srv.URL)
   152  				if !assert.NoError(t, err, `af.Fetch should succeed`) {
   153  					return
   154  				}
   155  
   156  				if !checkAccessCount(t, ctx, ks, 1) {
   157  					return
   158  				}
   159  			}()
   160  		}
   161  
   162  		t.Logf("Waiting for fetching goroutines...")
   163  		wg.Wait()
   164  		t.Logf("Waiting for the refresh ...")
   165  		time.Sleep(4 * time.Second)
   166  		ks, err := af.Fetch(ctx, srv.URL)
   167  		if !assert.NoError(t, err, `af.Fetch should succeed`) {
   168  			return
   169  		}
   170  		if !checkAccessCount(t, ctx, ks, 2) {
   171  			return
   172  		}
   173  	})
   174  	t.Run("Backoff", func(t *testing.T) {
   175  		t.Parallel()
   176  		ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
   177  		defer cancel()
   178  
   179  		var accessCount int
   180  		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   181  			accessCount++
   182  			if accessCount > 1 && accessCount < 4 {
   183  				http.Error(w, "wait for it....", http.StatusForbidden)
   184  				return
   185  			}
   186  
   187  			key := map[string]interface{}{
   188  				"kty":         "EC",
   189  				"crv":         "P-256",
   190  				"x":           "SVqB4JcUD6lsfvqMr-OKUNUphdNn64Eay60978ZlL74",
   191  				"y":           "lf0u0pMj4lGAzZix5u4Cm5CMQIgMNpkwy163wtKYVKI",
   192  				"accessCount": accessCount,
   193  			}
   194  			hdrs := w.Header()
   195  			hdrs.Set(`Content-Type`, `application/json`)
   196  			hdrs.Set(`Cache-Control`, `max-age=1`)
   197  
   198  			json.NewEncoder(w).Encode(key)
   199  		}))
   200  		defer srv.Close()
   201  
   202  		af := jwk.NewAutoRefresh(ctx)
   203  		bo := backoff.Constant(backoff.WithInterval(time.Second))
   204  		af.Configure(srv.URL, jwk.WithFetchBackoff(bo), jwk.WithMinRefreshInterval(1))
   205  
   206  		// First fetch should succeed
   207  		ks, err := af.Fetch(ctx, srv.URL)
   208  		if !assert.NoError(t, err, `af.Fetch (#1) should succed`) {
   209  			return
   210  		}
   211  		if !checkAccessCount(t, ctx, ks, 1) {
   212  			return
   213  		}
   214  
   215  		// enough time for 1 refresh to have occurred
   216  		time.Sleep(1500 * time.Millisecond)
   217  		ks, err = af.Fetch(ctx, srv.URL)
   218  		if !assert.NoError(t, err, `af.Fetch (#2) should succeed`) {
   219  			return
   220  		}
   221  		// Should be using the cached version
   222  		if !checkAccessCount(t, ctx, ks, 1) {
   223  			return
   224  		}
   225  
   226  		// enough time for 2 refreshes to have occurred
   227  		time.Sleep(2500 * time.Millisecond)
   228  
   229  		ks, err = af.Fetch(ctx, srv.URL)
   230  		if !assert.NoError(t, err, `af.Fetch (#3) should succeed`) {
   231  			return
   232  		}
   233  		// should be new
   234  		if !checkAccessCount(t, ctx, ks, 4, 5) {
   235  			return
   236  		}
   237  	})
   238  }
   239  
   240  func TestRefreshSnapshot(t *testing.T) {
   241  	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
   242  	defer cancel()
   243  
   244  	var jwksURLs []string
   245  	getJwksURL := func(dst *[]string, url string) bool {
   246  		req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
   247  		if err != nil {
   248  			return false
   249  		}
   250  
   251  		res, err := http.DefaultClient.Do(req)
   252  		if err != nil {
   253  			return false
   254  		}
   255  		defer res.Body.Close()
   256  
   257  		var m map[string]interface{}
   258  		if err := json.NewDecoder(res.Body).Decode(&m); err != nil {
   259  			return false
   260  		}
   261  
   262  		jwksURL, ok := m["jwks_uri"]
   263  		if !ok {
   264  			return false
   265  		}
   266  		*dst = append(*dst, jwksURL.(string))
   267  		return true
   268  	}
   269  	if !getJwksURL(&jwksURLs, "https://oidc-sample.onelogin.com/oidc/2/.well-known/openid-configuration") {
   270  		t.SkipNow()
   271  	}
   272  	if !getJwksURL(&jwksURLs, "https://accounts.google.com/.well-known/openid-configuration") {
   273  		t.SkipNow()
   274  	}
   275  
   276  	ar := jwk.NewAutoRefresh(ctx)
   277  	for _, url := range jwksURLs {
   278  		ar.Configure(url)
   279  	}
   280  
   281  	for _, url := range jwksURLs {
   282  		_, _ = ar.Refresh(ctx, url)
   283  	}
   284  
   285  	for target := range ar.Snapshot() {
   286  		t.Logf("%s last refreshed at %s, next refresh at %s", target.URL, target.LastRefresh, target.NextRefresh)
   287  	}
   288  
   289  	for _, url := range jwksURLs {
   290  		ar.Remove(url)
   291  	}
   292  
   293  	if !assert.Len(t, ar.Snapshot(), 0, `there should be no URLs`) {
   294  		return
   295  	}
   296  
   297  	if !assert.Error(t, ar.Remove(`dummy`), `removing a non-existing url should be an error`) {
   298  		return
   299  	}
   300  }
   301  
   302  func TestErrorSink(t *testing.T) {
   303  	t.Parallel()
   304  
   305  	k, err := jwxtest.GenerateRsaJwk()
   306  	if !assert.NoError(t, err, `jwxtest.GenerateRsaJwk should succeed`) {
   307  		return
   308  	}
   309  	set := jwk.NewSet()
   310  	set.Add(k)
   311  	testcases := []struct {
   312  		Name    string
   313  		Options func() []jwk.AutoRefreshOption
   314  		Handler http.Handler
   315  	}{
   316  		{
   317  			Name: "non-200 response",
   318  			Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   319  				w.WriteHeader(http.StatusForbidden)
   320  			}),
   321  		},
   322  		{
   323  			Name: "invalid JWK",
   324  			Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   325  				w.WriteHeader(http.StatusOK)
   326  				w.Write([]byte(`{"empty": "nonthingness"}`))
   327  			}),
   328  		},
   329  		{
   330  			Name: `rejected by whitelist`,
   331  			Options: func() []jwk.AutoRefreshOption {
   332  				return []jwk.AutoRefreshOption{
   333  					jwk.WithFetchWhitelist(jwk.WhitelistFunc(func(_ string) bool {
   334  						return false
   335  					})),
   336  				}
   337  			},
   338  			Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   339  				w.WriteHeader(http.StatusOK)
   340  				json.NewEncoder(w).Encode(k)
   341  			}),
   342  		},
   343  	}
   344  
   345  	for _, tc := range testcases {
   346  		tc := tc
   347  		t.Run(tc.Name, func(t *testing.T) {
   348  			t.Parallel()
   349  			ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
   350  			defer cancel()
   351  			srv := httptest.NewServer(tc.Handler)
   352  			defer srv.Close()
   353  
   354  			ar := jwk.NewAutoRefresh(ctx)
   355  
   356  			var options []jwk.AutoRefreshOption
   357  			if f := tc.Options; f != nil {
   358  				options = f()
   359  			}
   360  			options = append(options, jwk.WithRefreshInterval(500*time.Millisecond))
   361  			ar.Configure(srv.URL, options...)
   362  			ch := make(chan jwk.AutoRefreshError, 256) // big buffer
   363  			ar.ErrorSink(ch)
   364  			ar.Fetch(ctx, srv.URL)
   365  
   366  			timer := time.NewTimer(3 * time.Second)
   367  
   368  			select {
   369  			case <-ctx.Done():
   370  				t.Errorf(`ctx.Done before timer`)
   371  			case <-timer.C:
   372  			}
   373  
   374  			cancel() // forcefully end context, and thus the AutoRefresh
   375  
   376  			// timing issues can cause this to be non-deterministic...
   377  			// we'll say it's okay as long as we're in +/- 1 range
   378  			l := len(ch)
   379  			if !assert.True(t, l <= 7, "number of errors shold be less than or equal to 7 (%d)", l) {
   380  				return
   381  			}
   382  			if !assert.True(t, l >= 5, "number of errors shold be greather than or equal to 5 (%d)", l) {
   383  				return
   384  			}
   385  		})
   386  	}
   387  }
   388  
   389  func TestAutoRefreshRace(t *testing.T) {
   390  	k, err := jwxtest.GenerateRsaJwk()
   391  	if !assert.NoError(t, err, `jwxtest.GenerateRsaJwk should succeed`) {
   392  		return
   393  	}
   394  	set := jwk.NewSet()
   395  	set.Add(k)
   396  
   397  	// set up a server that always success since we need to update the registered target
   398  	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   399  		w.WriteHeader(http.StatusOK)
   400  		json.NewEncoder(w).Encode(k)
   401  	}))
   402  	defer srv.Close()
   403  
   404  	// configure a unique auto-refresh
   405  	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
   406  	defer cancel()
   407  	ar := jwk.NewAutoRefresh(ctx)
   408  	ch := make(chan jwk.AutoRefreshError, 256) // big buffer
   409  	ar.ErrorSink(ch)
   410  
   411  	wg := sync.WaitGroup{}
   412  	routineErr := make(chan error, 20)
   413  
   414  	// execute a bunch of parallel refresh forcing the requests to the server
   415  	// need to simulate configure happening also in the goroutine since this is
   416  	// the cause of races when refresh is updating the registered targets
   417  	for i := 0; i < 5000; i++ {
   418  		wg.Add(1)
   419  		go func() {
   420  			defer wg.Done()
   421  			ctx := context.Background()
   422  
   423  			ar.Configure(srv.URL, jwk.WithRefreshInterval(500*time.Millisecond))
   424  			_, err := ar.Refresh(ctx, srv.URL)
   425  
   426  			if err != nil {
   427  				routineErr <- err
   428  			}
   429  		}()
   430  	}
   431  	wg.Wait()
   432  
   433  	require.Len(t, routineErr, 0)
   434  }
   435  

View as plain text