1
2
3
4
5 package authhandler
6
7 import (
8 "context"
9 "fmt"
10 "net/http"
11 "net/http/httptest"
12 "testing"
13
14 "golang.org/x/oauth2"
15 )
16
17 func TestTokenExchange_Success(t *testing.T) {
18 authhandler := func(authCodeURL string) (string, string, error) {
19 if authCodeURL == "testAuthCodeURL?client_id=testClientID&response_type=code&scope=pubsub&state=testState" {
20 return "testCode", "testState", nil
21 }
22 return "", "", fmt.Errorf("invalid authCodeURL: %q", authCodeURL)
23 }
24
25 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
26 r.ParseForm()
27 if r.Form.Get("code") == "testCode" {
28 w.Header().Set("Content-Type", "application/json")
29 w.Write([]byte(`{
30 "access_token": "90d64460d14870c08c81352a05dedd3465940a7c",
31 "scope": "pubsub",
32 "token_type": "bearer",
33 "expires_in": 3600
34 }`))
35 }
36 }))
37 defer ts.Close()
38
39 conf := &oauth2.Config{
40 ClientID: "testClientID",
41 Scopes: []string{"pubsub"},
42 Endpoint: oauth2.Endpoint{
43 AuthURL: "testAuthCodeURL",
44 TokenURL: ts.URL,
45 },
46 }
47
48 tok, err := TokenSource(context.Background(), conf, "testState", authhandler).Token()
49 if err != nil {
50 t.Fatal(err)
51 }
52 if !tok.Valid() {
53 t.Errorf("got invalid token: %v", tok)
54 }
55 if got, want := tok.AccessToken, "90d64460d14870c08c81352a05dedd3465940a7c"; got != want {
56 t.Errorf("access token = %q; want %q", got, want)
57 }
58 if got, want := tok.TokenType, "bearer"; got != want {
59 t.Errorf("token type = %q; want %q", got, want)
60 }
61 if got := tok.Expiry.IsZero(); got {
62 t.Errorf("token expiry is zero = %v, want false", got)
63 }
64 scope := tok.Extra("scope")
65 if got, want := scope, "pubsub"; got != want {
66 t.Errorf("scope = %q; want %q", got, want)
67 }
68 }
69
70 func TestTokenExchange_StateMismatch(t *testing.T) {
71 authhandler := func(authCodeURL string) (string, string, error) {
72 return "testCode", "testStateMismatch", nil
73 }
74
75 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
76 w.Header().Set("Content-Type", "application/json")
77 w.Write([]byte(`{
78 "access_token": "90d64460d14870c08c81352a05dedd3465940a7c",
79 "scope": "pubsub",
80 "token_type": "bearer",
81 "expires_in": 3600
82 }`))
83 }))
84 defer ts.Close()
85
86 conf := &oauth2.Config{
87 ClientID: "testClientID",
88 Scopes: []string{"pubsub"},
89 Endpoint: oauth2.Endpoint{
90 AuthURL: "testAuthCodeURL",
91 TokenURL: ts.URL,
92 },
93 }
94
95 _, err := TokenSource(context.Background(), conf, "testState", authhandler).Token()
96 if want_err := "state mismatch in 3-legged-OAuth flow"; err == nil || err.Error() != want_err {
97 t.Errorf("err = %q; want %q", err, want_err)
98 }
99 }
100
101 func TestTokenExchangeWithPKCE_Success(t *testing.T) {
102 authhandler := func(authCodeURL string) (string, string, error) {
103 if authCodeURL == "testAuthCodeURL?client_id=testClientID&code_challenge=codeChallenge&code_challenge_method=plain&response_type=code&scope=pubsub&state=testState" {
104 return "testCode", "testState", nil
105 }
106 return "", "", fmt.Errorf("invalid authCodeURL: %q", authCodeURL)
107 }
108
109 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
110 r.ParseForm()
111 if r.Form.Get("code") == "testCode" && r.Form.Get("code_verifier") == "codeChallenge" {
112 w.Header().Set("Content-Type", "application/json")
113 w.Write([]byte(`{
114 "access_token": "90d64460d14870c08c81352a05dedd3465940a7c",
115 "scope": "pubsub",
116 "token_type": "bearer",
117 "expires_in": 3600
118 }`))
119 }
120 }))
121 defer ts.Close()
122
123 conf := &oauth2.Config{
124 ClientID: "testClientID",
125 Scopes: []string{"pubsub"},
126 Endpoint: oauth2.Endpoint{
127 AuthURL: "testAuthCodeURL",
128 TokenURL: ts.URL,
129 },
130 }
131 pkce := PKCEParams{
132 Challenge: "codeChallenge",
133 ChallengeMethod: "plain",
134 Verifier: "codeChallenge",
135 }
136
137 tok, err := TokenSourceWithPKCE(context.Background(), conf, "testState", authhandler, &pkce).Token()
138 if err != nil {
139 t.Fatal(err)
140 }
141 if !tok.Valid() {
142 t.Errorf("got invalid token: %v", tok)
143 }
144 if got, want := tok.AccessToken, "90d64460d14870c08c81352a05dedd3465940a7c"; got != want {
145 t.Errorf("access token = %q; want %q", got, want)
146 }
147 if got, want := tok.TokenType, "bearer"; got != want {
148 t.Errorf("token type = %q; want %q", got, want)
149 }
150 if got := tok.Expiry.IsZero(); got {
151 t.Errorf("token expiry is zero = %v, want false", got)
152 }
153 scope := tok.Extra("scope")
154 if got, want := scope, "pubsub"; got != want {
155 t.Errorf("scope = %q; want %q", got, want)
156 }
157 }
158
View as plain text