1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package auth
16
17 import (
18 "context"
19 "fmt"
20 "io"
21 "net/http"
22 "net/http/httptest"
23 "net/url"
24 "testing"
25 "time"
26 )
27
// day is one 24-hour period. The expiry tests use it both as the lifetime
// the fake server advertises and as the offset of the window the resulting
// token expiry must fall into.
const day time.Duration = 24 * time.Hour
29
30 func newOpts(url string) *Options3LO {
31 return &Options3LO{
32 ClientID: "CLIENT_ID",
33 ClientSecret: "CLIENT_SECRET",
34 RedirectURL: "REDIRECT_URL",
35 Scopes: []string{"scope1", "scope2"},
36 AuthURL: url + "/auth",
37 TokenURL: url + "/token",
38 AuthStyle: StyleInHeader,
39 RefreshToken: "OLD_REFRESH_TOKEN",
40 }
41 }
42
43 func Test3LO_URLUnsafe(t *testing.T) {
44 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
45 if got, want := r.Header.Get("Authorization"), "Basic Q0xJRU5UX0lEJTNGJTNGOkNMSUVOVF9TRUNSRVQlM0YlM0Y="; got != want {
46 t.Errorf("Authorization header = %q; want %q", got, want)
47 }
48
49 w.Header().Set("Content-Type", "application/x-www-form-urlencoded")
50 w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer"))
51 }))
52 defer ts.Close()
53 conf := newOpts(ts.URL)
54 conf.ClientID = "CLIENT_ID??"
55 conf.ClientSecret = "CLIENT_SECRET??"
56 _, _, err := conf.exchange(context.Background(), "exchange-code")
57 if err != nil {
58 t.Error(err)
59 }
60 }
61
62 func Test3LO_StandardExchange(t *testing.T) {
63 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
64 if r.URL.String() != "/token" {
65 t.Errorf("Unexpected exchange request URL %q", r.URL)
66 }
67 headerAuth := r.Header.Get("Authorization")
68 if want := "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ="; headerAuth != want {
69 t.Errorf("Unexpected authorization header %q, want %q", headerAuth, want)
70 }
71 headerContentType := r.Header.Get("Content-Type")
72 if headerContentType != "application/x-www-form-urlencoded" {
73 t.Errorf("Unexpected Content-Type header %q", headerContentType)
74 }
75 body, err := io.ReadAll(r.Body)
76 if err != nil {
77 t.Errorf("Failed reading request body: %s.", err)
78 }
79 if string(body) != "code=exchange-code&grant_type=authorization_code&redirect_uri=REDIRECT_URL" {
80 t.Errorf("Unexpected exchange payload; got %q", body)
81 }
82 w.Header().Set("Content-Type", "application/x-www-form-urlencoded")
83 w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer"))
84 }))
85 defer ts.Close()
86 conf := newOpts(ts.URL)
87 tok, _, err := conf.exchange(context.Background(), "exchange-code")
88 if err != nil {
89 t.Error(err)
90 }
91 if !tok.IsValid() {
92 t.Fatalf("Token invalid. Got: %#v", tok)
93 }
94 if tok.Value != "90d64460d14870c08c81352a05dedd3465940a7c" {
95 t.Errorf("Unexpected access token, %#v.", tok.Value)
96 }
97 if tok.Type != "bearer" {
98 t.Errorf("Unexpected token type, %#v.", tok.Type)
99 }
100 scope := tok.Metadata["scope"].([]string)
101 if scope[0] != "user" {
102 t.Errorf("Unexpected value for scope: %v", scope)
103 }
104 }
105
106 func Test3LO_ExchangeCustomParams(t *testing.T) {
107 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
108 if r.URL.String() != "/token" {
109 t.Errorf("Unexpected exchange request URL, %v is found.", r.URL)
110 }
111 headerAuth := r.Header.Get("Authorization")
112 if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" {
113 t.Errorf("Unexpected authorization header, %v is found.", headerAuth)
114 }
115 headerContentType := r.Header.Get("Content-Type")
116 if headerContentType != "application/x-www-form-urlencoded" {
117 t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
118 }
119 body, err := io.ReadAll(r.Body)
120 if err != nil {
121 t.Errorf("Failed reading request body: %s.", err)
122 }
123 if string(body) != "code=exchange-code&foo=bar&grant_type=authorization_code&redirect_uri=REDIRECT_URL" {
124 t.Errorf("Unexpected exchange payload, %v is found.", string(body))
125 }
126 w.Header().Set("Content-Type", "application/x-www-form-urlencoded")
127 w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer"))
128 }))
129 defer ts.Close()
130 conf := newOpts(ts.URL)
131 conf.URLParams = url.Values{}
132 conf.URLParams.Set("foo", "bar")
133
134 tok, _, err := conf.exchange(context.Background(), "exchange-code")
135 if err != nil {
136 t.Error(err)
137 }
138 if !tok.IsValid() {
139 t.Fatalf("Token invalid. Got: %#v", tok)
140 }
141 if tok.Value != "90d64460d14870c08c81352a05dedd3465940a7c" {
142 t.Errorf("Unexpected access token, %#v.", tok.Value)
143 }
144 if tok.Type != "bearer" {
145 t.Errorf("Unexpected token type, %#v.", tok.Type)
146 }
147 scope := tok.Metadata["scope"].([]string)
148 if scope[0] != "user" {
149 t.Errorf("Unexpected value for scope: %v", scope)
150 }
151 }
152
153 func Test3LO_ExchangeJSONResponse(t *testing.T) {
154 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
155 if r.URL.String() != "/token" {
156 t.Errorf("Unexpected exchange request URL, %v is found.", r.URL)
157 }
158 headerAuth := r.Header.Get("Authorization")
159 if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" {
160 t.Errorf("Unexpected authorization header, %v is found.", headerAuth)
161 }
162 headerContentType := r.Header.Get("Content-Type")
163 if headerContentType != "application/x-www-form-urlencoded" {
164 t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
165 }
166 body, err := io.ReadAll(r.Body)
167 if err != nil {
168 t.Errorf("Failed reading request body: %s.", err)
169 }
170 if string(body) != "code=exchange-code&grant_type=authorization_code&redirect_uri=REDIRECT_URL" {
171 t.Errorf("Unexpected exchange payload, %v is found.", string(body))
172 }
173 w.Header().Set("Content-Type", "application/json")
174 w.Write([]byte(`{"access_token": "90d64460d14870c08c81352a05dedd3465940a7c", "scope": "user", "token_type": "bearer", "expires_in": 86400}`))
175 }))
176 defer ts.Close()
177 conf := newOpts(ts.URL)
178 tok, _, err := conf.exchange(context.Background(), "exchange-code")
179 if err != nil {
180 t.Error(err)
181 }
182 if !tok.IsValid() {
183 t.Fatalf("Token invalid. Got: %#v", tok)
184 }
185 if tok.Value != "90d64460d14870c08c81352a05dedd3465940a7c" {
186 t.Errorf("Unexpected access token, %#v.", tok.Value)
187 }
188 if tok.Type != "bearer" {
189 t.Errorf("Unexpected token type, %#v.", tok.Type)
190 }
191 scope := tok.Metadata["scope"].(string)
192 if scope != "user" {
193 t.Errorf("Unexpected value for scope: %v", scope)
194 }
195 expiresIn := tok.Metadata["expires_in"]
196 if expiresIn != float64(86400) {
197 t.Errorf("Unexpected non-numeric value for expires_in: %v", expiresIn)
198 }
199 }
200
201 func Test3LO_ExchangeJSONResponseExpiry(t *testing.T) {
202 seconds := int32(day.Seconds())
203 for _, c := range []struct {
204 name string
205 expires string
206 want bool
207 nullExpires bool
208 }{
209 {"normal", fmt.Sprintf(`"expires_in": %d`, seconds), true, false},
210 {"null", `"expires_in": null`, true, true},
211 {"wrong_type", `"expires_in": false`, false, false},
212 {"wrong_type2", `"expires_in": {}`, false, false},
213 {"wrong_value", `"expires_in": "zzz"`, false, false},
214 } {
215 t.Run(c.name, func(t *testing.T) {
216 test3LOExchangeJSONResponseExpiry(t, c.expires, c.want, c.nullExpires)
217 })
218 }
219 }
220
221 func test3LOExchangeJSONResponseExpiry(t *testing.T, exp string, want, nullExpires bool) {
222 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
223 w.Header().Set("Content-Type", "application/json")
224 w.Write([]byte(fmt.Sprintf(`{"access_token": "90d", "scope": "user", "token_type": "bearer", %s}`, exp)))
225 }))
226 defer ts.Close()
227 conf := newOpts(ts.URL)
228 t1 := time.Now().Add(day)
229 tok, _, err := conf.exchange(context.Background(), "exchange-code")
230 t2 := t1.Add(day)
231
232 if got := (err == nil); got != want {
233 if want {
234 t.Errorf("unexpected error: got %v", err)
235 } else {
236 t.Errorf("unexpected success")
237 }
238 }
239 if !want {
240 return
241 }
242 if !tok.IsValid() {
243 t.Fatalf("Token invalid. Got: %#v", tok)
244 }
245 expiry := tok.Expiry
246
247 if nullExpires && expiry.IsZero() {
248 return
249 }
250 if expiry.Before(t1) || expiry.After(t2) {
251 t.Errorf("Unexpected value for Expiry: %v (should be between %v and %v)", expiry, t1, t2)
252 }
253 }
254
255 func Test3LO_ExchangeBadResponse(t *testing.T) {
256 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
257 w.Header().Set("Content-Type", "application/json")
258 w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`))
259 }))
260 defer ts.Close()
261 conf := newOpts(ts.URL)
262 _, _, err := conf.exchange(context.Background(), "code")
263 if err == nil {
264 t.Error("expected error from missing access_token")
265 }
266 }
267
268 func Test3LO_ExchangeBadResponseType(t *testing.T) {
269 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
270 w.Header().Set("Content-Type", "application/json")
271 w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`))
272 }))
273 defer ts.Close()
274 conf := newOpts(ts.URL)
275 _, _, err := conf.exchange(context.Background(), "exchange-code")
276 if err == nil {
277 t.Error("expected error from non-string access_token")
278 }
279 }
280
281 func Test3LO_RefreshTokenReplacement(t *testing.T) {
282 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
283 w.Header().Set("Content-Type", "application/json")
284 w.Write([]byte(`{"access_token":"ACCESS_TOKEN", "scope": "user", "token_type": "bearer", "refresh_token": "NEW_REFRESH_TOKEN"}`))
285 }))
286 defer ts.Close()
287 opts := newOpts(ts.URL)
288 tp, err := New3LOTokenProvider(opts)
289 if err != nil {
290 t.Fatal(err)
291 }
292 if _, err := tp.Token(context.Background()); err != nil {
293 t.Errorf("got err = %v; want none", err)
294 return
295 }
296 innerTP := tp.(*cachedTokenProvider).tp.(*tokenProvider3LO)
297 if want := "NEW_REFRESH_TOKEN"; innerTP.refreshToken != want {
298 t.Errorf("RefreshToken = %q; want %q", innerTP.refreshToken, want)
299 }
300 }
301
302 func Test3LO_RefreshTokenPreservation(t *testing.T) {
303 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
304 w.Header().Set("Content-Type", "application/json")
305 w.Write([]byte(`{"access_token":"ACCESS_TOKEN", "scope": "user", "token_type": "bearer"}`))
306 }))
307 defer ts.Close()
308 opts := newOpts(ts.URL)
309 const oldRefreshToken = "OLD_REFRESH_TOKEN"
310 tp, err := New3LOTokenProvider(opts)
311 if err != nil {
312 t.Fatal(err)
313 }
314 if _, err := tp.Token(context.Background()); err != nil {
315 t.Errorf("got err = %v; want none", err)
316 return
317 }
318 innerTP := tp.(*cachedTokenProvider).tp.(*tokenProvider3LO)
319 if innerTP.refreshToken != oldRefreshToken {
320 t.Errorf("RefreshToken = %q; want %q", innerTP.refreshToken, oldRefreshToken)
321 }
322 }
323
324 func Test3LO_AuthHandlerExchangeSuccess(t *testing.T) {
325 authhandler := func(authCodeURL string) (string, string, error) {
326 if authCodeURL == "testAuthCodeURL?client_id=testClientID&response_type=code&scope=pubsub&state=testState" {
327 return "testCode", "testState", nil
328 }
329 return "", "", fmt.Errorf("invalid authCodeURL: %q", authCodeURL)
330 }
331
332 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
333 r.ParseForm()
334 if r.Form.Get("code") == "testCode" {
335 w.Header().Set("Content-Type", "application/json")
336 w.Write([]byte(`{
337 "access_token": "90d64460d14870c08c81352a05dedd3465940a7c",
338 "scope": "pubsub",
339 "token_type": "bearer",
340 "expires_in": 3600
341 }`))
342 }
343 }))
344 defer ts.Close()
345
346 opts := &Options3LO{
347 ClientID: "testClientID",
348 Scopes: []string{"pubsub"},
349 AuthURL: "testAuthCodeURL",
350 TokenURL: ts.URL,
351 AuthStyle: StyleInHeader,
352 AuthHandlerOpts: &AuthorizationHandlerOptions{
353 State: "testState",
354 Handler: authhandler,
355 },
356 }
357
358 tp, err := New3LOTokenProvider(opts)
359 if err != nil {
360 t.Fatal(err)
361 }
362 tok, err := tp.Token(context.Background())
363 if err != nil {
364 t.Fatal(err)
365 }
366 if !tok.IsValid() {
367 t.Errorf("got invalid token: %v", tok)
368 }
369 if got, want := tok.Value, "90d64460d14870c08c81352a05dedd3465940a7c"; got != want {
370 t.Errorf("access token = %q; want %q", got, want)
371 }
372 if got, want := tok.Type, "bearer"; got != want {
373 t.Errorf("token type = %q; want %q", got, want)
374 }
375 if got := tok.Expiry.IsZero(); got {
376 t.Errorf("token expiry is zero = %v, want false", got)
377 }
378 scope := tok.Metadata["scope"].(string)
379 if got, want := scope, "pubsub"; got != want {
380 t.Errorf("scope = %q; want %q", got, want)
381 }
382 }
383
384 func Test3LO_AuthHandlerExchangeStateMismatch(t *testing.T) {
385 authhandler := func(authCodeURL string) (string, string, error) {
386 return "testCode", "testStateMismatch", nil
387 }
388
389 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
390 w.Header().Set("Content-Type", "application/json")
391 w.Write([]byte(`{
392 "access_token": "90d64460d14870c08c81352a05dedd3465940a7c",
393 "scope": "pubsub",
394 "token_type": "bearer",
395 "expires_in": 3600
396 }`))
397 }))
398 defer ts.Close()
399
400 opts := &Options3LO{
401 ClientID: "testClientID",
402 Scopes: []string{"pubsub"},
403 AuthURL: "testAuthCodeURL",
404 TokenURL: ts.URL,
405 AuthStyle: StyleInParams,
406 AuthHandlerOpts: &AuthorizationHandlerOptions{
407 State: "testState",
408 Handler: authhandler,
409 },
410 }
411 tp, err := New3LOTokenProvider(opts)
412 if err != nil {
413 t.Fatal(err)
414 }
415 _, err = tp.Token(context.Background())
416 if wantErr := "auth: state mismatch in 3-legged-OAuth flow"; err == nil || err.Error() != wantErr {
417 t.Errorf("err = %q; want %q", err, wantErr)
418 }
419 }
420
421 func Test3LO_PKCEExchangeWithSuccess(t *testing.T) {
422 authhandler := func(authCodeURL string) (string, string, error) {
423 if authCodeURL == "testAuthCodeURL?client_id=testClientID&code_challenge=codeChallenge&code_challenge_method=plain&response_type=code&scope=pubsub&state=testState" {
424 return "testCode", "testState", nil
425 }
426 return "", "", fmt.Errorf("invalid authCodeURL: %q", authCodeURL)
427 }
428
429 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
430 r.ParseForm()
431 if r.Form.Get("code") == "testCode" && r.Form.Get("code_verifier") == "codeChallenge" {
432 w.Header().Set("Content-Type", "application/json")
433 w.Write([]byte(`{
434 "access_token": "90d64460d14870c08c81352a05dedd3465940a7c",
435 "scope": "pubsub",
436 "token_type": "bearer",
437 "expires_in": 3600
438 }`))
439 }
440 }))
441 defer ts.Close()
442
443 opts := &Options3LO{
444 ClientID: "testClientID",
445 Scopes: []string{"pubsub"},
446 AuthURL: "testAuthCodeURL",
447 TokenURL: ts.URL,
448 AuthStyle: StyleInParams,
449 AuthHandlerOpts: &AuthorizationHandlerOptions{
450 State: "testState",
451 Handler: authhandler,
452 PKCEOpts: &PKCEOptions{
453 Challenge: "codeChallenge",
454 ChallengeMethod: "plain",
455 Verifier: "codeChallenge",
456 },
457 },
458 }
459
460 tp, err := New3LOTokenProvider(opts)
461 if err != nil {
462 t.Fatal(err)
463 }
464 tok, err := tp.Token(context.Background())
465 if err != nil {
466 t.Fatal(err)
467 }
468 if !tok.IsValid() {
469 t.Errorf("got invalid token: %v", tok)
470 }
471 if got, want := tok.Value, "90d64460d14870c08c81352a05dedd3465940a7c"; got != want {
472 t.Errorf("access token = %q; want %q", got, want)
473 }
474 if got, want := tok.Type, "bearer"; got != want {
475 t.Errorf("token type = %q; want %q", got, want)
476 }
477 if got := tok.Expiry.IsZero(); got {
478 t.Errorf("token expiry is zero = %v, want false", got)
479 }
480 scope := tok.Metadata["scope"].(string)
481 if got, want := scope, "pubsub"; got != want {
482 t.Errorf("scope = %q; want %q", got, want)
483 }
484 }
485
486 func Test3LO_Validate(t *testing.T) {
487 tests := []struct {
488 name string
489 opts *Options3LO
490 }{
491 {
492 name: "missing options",
493 },
494 {
495 name: "missing client ID",
496 opts: &Options3LO{
497 ClientSecret: "client_secret",
498 AuthURL: "auth_url",
499 TokenURL: "token_url",
500 AuthStyle: StyleInHeader,
501 RefreshToken: "refreshing",
502 },
503 },
504 {
505 name: "missing client secret",
506 opts: &Options3LO{
507 ClientID: "client_id",
508 AuthURL: "auth_url",
509 TokenURL: "token_url",
510 AuthStyle: StyleInHeader,
511 RefreshToken: "refreshing",
512 },
513 },
514 {
515 name: "missing auth URL",
516 opts: &Options3LO{
517 ClientID: "client_id",
518 ClientSecret: "client_secret",
519 TokenURL: "token_url",
520 AuthStyle: StyleInHeader,
521 RefreshToken: "refreshing",
522 },
523 },
524 {
525 name: "missing token URL",
526 opts: &Options3LO{
527 ClientID: "client_id",
528 ClientSecret: "client_secret",
529 AuthURL: "auth_url",
530 AuthStyle: StyleInHeader,
531 RefreshToken: "refreshing",
532 },
533 },
534 {
535 name: "missing auth style",
536 opts: &Options3LO{
537 ClientID: "client_id",
538 ClientSecret: "client_secret",
539 AuthURL: "auth_url",
540 TokenURL: "token_url",
541 RefreshToken: "refreshing",
542 },
543 },
544 {
545 name: "missing refresh token",
546 opts: &Options3LO{
547 ClientID: "client_id",
548 ClientSecret: "client_secret",
549 AuthURL: "auth_url",
550 TokenURL: "token_url",
551 AuthStyle: StyleInHeader,
552 },
553 },
554 }
555 for _, tt := range tests {
556 t.Run(tt.name, func(t *testing.T) {
557 if _, err := New3LOTokenProvider(tt.opts); err == nil {
558 t.Error("got nil, want an error")
559 }
560 })
561 }
562 }
563
View as plain text