1
16
17 package transport
18
19 import (
20 "fmt"
21 "net/http"
22 "reflect"
23 "sync"
24 "testing"
25 "time"
26
27 "golang.org/x/oauth2"
28 )
29
30 type testTokenSource struct {
31 calls int
32 tok *oauth2.Token
33 err error
34 }
35
36 func (ts *testTokenSource) Token() (*oauth2.Token, error) {
37 ts.calls++
38 return ts.tok, ts.err
39 }
40
41 func TestCachingTokenSource(t *testing.T) {
42 start := time.Now()
43 tokA := &oauth2.Token{
44 AccessToken: "a",
45 Expiry: start.Add(10 * time.Minute),
46 }
47 tokB := &oauth2.Token{
48 AccessToken: "b",
49 Expiry: start.Add(20 * time.Minute),
50 }
51 tests := []struct {
52 name string
53
54 tok *oauth2.Token
55 tsTok *oauth2.Token
56 tsErr error
57 wait time.Duration
58
59 wantTok *oauth2.Token
60 wantErr bool
61 wantTSCalls int
62 }{
63 {
64 name: "valid token returned from cache",
65 tok: tokA,
66 wantTok: tokA,
67 },
68 {
69 name: "valid token returned from cache 1 minute before scheduled refresh",
70 tok: tokA,
71 wait: 8 * time.Minute,
72 wantTok: tokA,
73 },
74 {
75 name: "new token created when cache is empty",
76 tsTok: tokA,
77 wantTok: tokA,
78 wantTSCalls: 1,
79 },
80 {
81 name: "new token created 1 minute after scheduled refresh",
82 tok: tokA,
83 tsTok: tokB,
84 wait: 10 * time.Minute,
85 wantTok: tokB,
86 wantTSCalls: 1,
87 },
88 {
89 name: "error on create token returns error",
90 tsErr: fmt.Errorf("error"),
91 wantErr: true,
92 wantTSCalls: 1,
93 },
94 }
95 for _, c := range tests {
96 t.Run(c.name, func(t *testing.T) {
97 tts := &testTokenSource{
98 tok: c.tsTok,
99 err: c.tsErr,
100 }
101
102 ts := &cachingTokenSource{
103 base: tts,
104 tok: c.tok,
105 leeway: 1 * time.Minute,
106 now: func() time.Time { return start.Add(c.wait) },
107 }
108
109 gotTok, gotErr := ts.Token()
110 if got, want := gotTok, c.wantTok; !reflect.DeepEqual(got, want) {
111 t.Errorf("unexpected token:\n\tgot:\t%#v\n\twant:\t%#v", got, want)
112 }
113 if got, want := tts.calls, c.wantTSCalls; got != want {
114 t.Errorf("unexpected number of Token() calls: got %d, want %d", got, want)
115 }
116 if gotErr == nil && c.wantErr {
117 t.Errorf("wanted error but got none")
118 }
119 if gotErr != nil && !c.wantErr {
120 t.Errorf("unexpected error: %v", gotErr)
121 }
122 })
123 }
124 }
125
126 func TestCachingTokenSourceRace(t *testing.T) {
127 for i := 0; i < 100; i++ {
128 tts := &testTokenSource{
129 tok: &oauth2.Token{
130 AccessToken: "a",
131 Expiry: time.Now().Add(1000 * time.Hour),
132 },
133 }
134
135 ts := &cachingTokenSource{
136 now: time.Now,
137 base: tts,
138 leeway: 1 * time.Minute,
139 }
140
141 var wg sync.WaitGroup
142 wg.Add(100)
143 errc := make(chan error, 100)
144
145 for i := 0; i < 100; i++ {
146 go func() {
147 defer wg.Done()
148 if _, err := ts.Token(); err != nil {
149 errc <- err
150 }
151 }()
152 }
153 go func() {
154 wg.Wait()
155 close(errc)
156 }()
157 if err, ok := <-errc; ok {
158 t.Fatalf("err: %v", err)
159 }
160 if tts.calls != 1 {
161 t.Errorf("expected one call to Token() but saw: %d", tts.calls)
162 }
163 }
164 }
165
166 func TestTokenSourceTransportRoundTrip(t *testing.T) {
167 goodToken := &oauth2.Token{
168 AccessToken: "good",
169 Expiry: time.Now().Add(1000 * time.Hour),
170 }
171 badToken := &oauth2.Token{
172 AccessToken: "bad",
173 Expiry: time.Now().Add(1000 * time.Hour),
174 }
175 tests := []struct {
176 name string
177 header http.Header
178 token *oauth2.Token
179 cachedToken *oauth2.Token
180 wantCalls int
181 wantCaching bool
182 }{
183 {
184 name: "skip oauth rt if has authorization header",
185 header: map[string][]string{"Authorization": {"Bearer TOKEN"}},
186 token: goodToken,
187 },
188 {
189 name: "authorized on newly acquired good token",
190 token: goodToken,
191 wantCalls: 1,
192 wantCaching: true,
193 },
194 {
195 name: "authorized on cached good token",
196 token: goodToken,
197 cachedToken: goodToken,
198 wantCalls: 0,
199 wantCaching: true,
200 },
201 {
202 name: "unauthorized on newly acquired bad token",
203 token: badToken,
204 wantCalls: 1,
205 wantCaching: true,
206 },
207 {
208 name: "unauthorized on cached bad token",
209 token: badToken,
210 cachedToken: badToken,
211 wantCalls: 0,
212 },
213 }
214 for _, test := range tests {
215 t.Run(test.name, func(t *testing.T) {
216 tts := &testTokenSource{
217 tok: test.token,
218 }
219 cachedTokenSource := NewCachedTokenSource(tts)
220 cachedTokenSource.tok = test.cachedToken
221
222 rt := ResettableTokenSourceWrapTransport(cachedTokenSource)(&testTransport{})
223
224 rt.RoundTrip(&http.Request{Header: test.header})
225 if tts.calls != test.wantCalls {
226 t.Errorf("RoundTrip() called Token() = %d times, want %d", tts.calls, test.wantCalls)
227 }
228
229 if (cachedTokenSource.tok != nil) != test.wantCaching {
230 t.Errorf("Got caching %v, want caching %v", cachedTokenSource != nil, test.wantCaching)
231 }
232 })
233 }
234 }
235
// uncancellableRT forwards requests to an inner RoundTripper but exposes only
// RoundTrip. It has neither a CancelRequest nor an unwrap method, so helpers
// that look for those methods cannot reach the wrapped transport through it.
type uncancellableRT struct {
	rt http.RoundTripper // delegate that actually performs the round trip
}

// RoundTrip hands req to the wrapped transport unchanged and returns its result.
func (u *uncancellableRT) RoundTrip(req *http.Request) (*http.Response, error) {
	inner := u.rt
	return inner.RoundTrip(req)
}
243
244 func TestTokenSourceTransportCancelRequest(t *testing.T) {
245 tests := []struct {
246 name string
247 header http.Header
248 wrapTransport func(http.RoundTripper) http.RoundTripper
249 expectCancel bool
250 }{
251 {
252 name: "cancel req with bearer token skips oauth rt",
253 header: map[string][]string{"Authorization": {"Bearer TOKEN"}},
254 expectCancel: true,
255 },
256 {
257 name: "can't cancel request with rts that doesn't implent unwrap or cancel",
258 wrapTransport: func(rt http.RoundTripper) http.RoundTripper {
259 return &uncancellableRT{rt: rt}
260 },
261 expectCancel: false,
262 },
263 }
264 for _, test := range tests {
265 t.Run(test.name, func(t *testing.T) {
266 baseRecorder := &testTransport{}
267
268 var base http.RoundTripper = baseRecorder
269 if test.wrapTransport != nil {
270 base = test.wrapTransport(base)
271 }
272
273 rt := &tokenSourceTransport{
274 base: base,
275 ort: &oauth2.Transport{
276 Base: base,
277 },
278 }
279
280 rt.CancelRequest(&http.Request{
281 Header: test.header,
282 })
283
284 if baseRecorder.canceled != test.expectCancel {
285 t.Errorf("unexpected cancel: got=%v, want=%v", baseRecorder.canceled, test.expectCancel)
286 }
287 })
288 }
289 }
290
// testTransport is a fake http.RoundTripper used by the transport tests. It
// rejects requests carrying the "bad" bearer token with a 401 response and
// returns a nil response for everything else.
type testTransport struct {
	canceled bool              // set to true by CancelRequest
	base     http.RoundTripper // optional inner transport that cancellation is forwarded to
}

// RoundTrip answers 401 Unauthorized when the request's Authorization header
// is exactly "Bearer bad". Otherwise it returns a nil response and nil error.
// Header.Get is used so that a request with no Authorization header falls
// through to the nil response instead of panicking on an out-of-range index.
func (rt *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	if req.Header.Get("Authorization") == "Bearer bad" {
		return &http.Response{StatusCode: http.StatusUnauthorized}, nil
	}
	return nil, nil
}
302
303 func (rt *testTransport) CancelRequest(req *http.Request) {
304 rt.canceled = true
305 if rt.base != nil {
306 tryCancelRequest(rt.base, req)
307 }
308 }
309
View as plain text