1 package keyfunc_test
2
3 import (
4 "context"
5 "net/http"
6 "net/http/httptest"
7 "sync/atomic"
8 "testing"
9 "time"
10
11 "github.com/MicahParks/keyfunc/v2"
12 )
13
14 func TestJWKS_Refresh(t *testing.T) {
15 ctx, cancel := context.WithTimeout(context.Background(), time.Second)
16 defer cancel()
17
18 var counter uint64
19 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
20 atomic.AddUint64(&counter, 1)
21 _, err := w.Write([]byte(jwksJSON))
22 if err != nil {
23 http.Error(w, err.Error(), http.StatusInternalServerError)
24 }
25 }))
26 defer server.Close()
27
28 jwksURL := server.URL
29 opts := keyfunc.Options{
30 Ctx: ctx,
31 }
32 jwks, err := keyfunc.Get(jwksURL, opts)
33 if err != nil {
34 t.Fatalf(logFmt, "Failed to get JWKS from testing URL.", err)
35 }
36
37 err = jwks.Refresh(ctx, keyfunc.RefreshOptions{IgnoreRateLimit: true})
38 if err != nil {
39 t.Fatalf(logFmt, "Failed to refresh JWKS.", err)
40 }
41
42 count := atomic.LoadUint64(&counter)
43 if count != 2 {
44 t.Fatalf("Expected 2 refreshes, got %d.", count)
45 }
46 }
47
48 func TestJWKS_RefreshUsingBackgroundGoroutine(t *testing.T) {
49 ctx, cancel := context.WithTimeout(context.Background(), time.Second)
50 defer cancel()
51
52 var counter uint64
53 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
54 atomic.AddUint64(&counter, 1)
55 _, err := w.Write([]byte(jwksJSON))
56 if err != nil {
57 http.Error(w, err.Error(), http.StatusInternalServerError)
58 }
59 }))
60 defer server.Close()
61
62 jwksURL := server.URL
63 opts := keyfunc.Options{
64 Ctx: ctx,
65 RefreshInterval: time.Hour,
66 RefreshRateLimit: time.Hour,
67 }
68 jwks, err := keyfunc.Get(jwksURL, opts)
69 if err != nil {
70 t.Fatalf(logFmt, "Failed to get JWKS from testing URL.", err)
71 }
72
73 err = jwks.Refresh(ctx, keyfunc.RefreshOptions{IgnoreRateLimit: true})
74 if err != nil {
75 t.Fatalf(logFmt, "Failed to refresh JWKS.", err)
76 }
77
78 count := atomic.LoadUint64(&counter)
79 if count != 2 {
80 t.Fatalf("Expected 2 refreshes, got %d.", count)
81 }
82 }
83
// TestJWKS_RefreshCancelCtx checks how many requests reach the JWKS endpoint
// over roughly three refresh intervals. Each case varies whether Options.Ctx
// is provided and whether it is canceled part-way through. The count includes
// the initial fetch made by keyfunc.Get.
//
// Expected timeline with RefreshInterval = 1s. This assumes background
// refreshes fire about once per interval after Get.
//
//	~0s     keyfunc.Get fetches the JWKS (request 1).
//	~1s     first background refresh (request 2).
//	~1.1s   Options.Ctx is canceled in the "cancel Options.Ctx" case.
//	~2s     second background refresh (request 3), expected only when the
//	        context was not canceled.
//	~2.2s   EndBackground is called.
//	~3s     no refresh expected; the background goroutine should be stopped.
//
// Each sleep overshoots its interval boundary by only ~100ms, so this test
// is sensitive to scheduling delays.
func TestJWKS_RefreshCancelCtx(t *testing.T) {
	tests := map[string]struct {
		provideOptionsCtx bool
		cancelOptionsCtx  bool
		expectedRefreshes int // total requests to the test server, including the initial Get
	}{
		"cancel Options.Ctx": {
			provideOptionsCtx: true,
			cancelOptionsCtx:  true,
			expectedRefreshes: 2,
		},
		"do not cancel Options.Ctx": {
			provideOptionsCtx: true,
			cancelOptionsCtx:  false,
			expectedRefreshes: 3,
		},
		"do not provide Options.Ctx": {
			provideOptionsCtx: false,
			cancelOptionsCtx:  false,
			expectedRefreshes: 3,
		},
	}

	for name, tc := range tests {
		// No subtest calls t.Parallel, so t.Run finishes each subtest before
		// the next iteration and capturing tc in the closure is safe.
		t.Run(name, func(t *testing.T) {
			// ctx stays nil when no Options.Ctx is provided. Presumably keyfunc
			// then falls back to its own context — confirm. cancel starts as a
			// no-op so the conditional cancel() below is always safe to call.
			var (
				ctx    context.Context
				cancel = func() {}
			)
			if tc.provideOptionsCtx {
				ctx, cancel = context.WithCancel(context.Background())
				defer cancel()
			}

			// counter records every request the test server receives: the
			// initial fetch by keyfunc.Get plus each background refresh.
			var counter uint64
			server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				atomic.AddUint64(&counter, 1)
				_, err := w.Write([]byte(jwksJSON))
				if err != nil {
					http.Error(w, err.Error(), http.StatusInternalServerError)
				}
			}))
			defer server.Close()

			jwksURL := server.URL
			// The one-second RefreshInterval drives the background refreshes
			// that the sleeps below are timed against.
			opts := keyfunc.Options{
				Ctx:             ctx,
				RefreshInterval: 1 * time.Second,
			}
			jwks, err := keyfunc.Get(jwksURL, opts)
			if err != nil {
				t.Fatalf(logFmt, "Failed to get JWKS from testing URL.", err)
			}

			// Wait just past the first refresh interval so the background
			// goroutine has refreshed the JWKS at least once.
			time.Sleep(1100 * time.Millisecond)

			if tc.cancelOptionsCtx {
				cancel()
			}

			// Wait through another refresh interval. The JWKS should be
			// refreshed again only if Options.Ctx was not canceled above.
			time.Sleep(1101 * time.Millisecond)

			jwks.EndBackground()

			// Wait through one more refresh interval to check that no further
			// refresh happens once EndBackground has been called.
			time.Sleep(1100 * time.Millisecond)

			count := atomic.LoadUint64(&counter)
			if count != uint64(tc.expectedRefreshes) {
				t.Fatalf("Expected %d refreshes, got %d.", tc.expectedRefreshes, count)
			}
		})
	}
}
165
View as plain text