1
2
3
4
5 package externalaccountauthorizeduser
6
7 import (
8 "context"
9 "encoding/json"
10 "errors"
11 "io/ioutil"
12 "net/http"
13 "net/http/httptest"
14 "testing"
15 "time"
16
17 "golang.org/x/oauth2"
18 "golang.org/x/oauth2/google/internal/stsexchange"
19 )
20
// expiryDelta is the safety margin subtracted from a token's expiry by
// testValid before comparing it against testNow().
const expiryDelta time.Duration = 10 * time.Second
22
23 var (
24 expiry = time.Unix(234852, 0)
25 testNow = func() time.Time { return expiry }
26 testValid = func(t oauth2.Token) bool {
27 return t.AccessToken != "" && !t.Expiry.Round(0).Add(-expiryDelta).Before(testNow())
28 }
29 )
30
31 type testRefreshTokenServer struct {
32 URL string
33 Authorization string
34 ContentType string
35 Body string
36 ResponsePayload *stsexchange.Response
37 Response string
38 server *httptest.Server
39 }
40
41 func TestExernalAccountAuthorizedUser_JustToken(t *testing.T) {
42 config := &Config{
43 Token: "AAAAAAA",
44 Expiry: now().Add(time.Hour),
45 }
46 ts, err := config.TokenSource(context.Background())
47 if err != nil {
48 t.Fatalf("Error getting token source: %v", err)
49 }
50
51 token, err := ts.Token()
52 if err != nil {
53 t.Fatalf("Error retrieving Token: %v", err)
54 }
55 if got, want := token.AccessToken, "AAAAAAA"; got != want {
56 t.Fatalf("Unexpected access token, got %v, want %v", got, want)
57 }
58 }
59
60 func TestExernalAccountAuthorizedUser_TokenRefreshWithRefreshTokenInRespondse(t *testing.T) {
61 server := &testRefreshTokenServer{
62 URL: "/",
63 Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
64 ContentType: "application/x-www-form-urlencoded",
65 Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
66 ResponsePayload: &stsexchange.Response{
67 ExpiresIn: 3600,
68 AccessToken: "AAAAAAA",
69 RefreshToken: "CCCCCCC",
70 },
71 }
72
73 url, err := server.run(t)
74 if err != nil {
75 t.Fatalf("Error starting server")
76 }
77 defer server.close(t)
78
79 config := &Config{
80 RefreshToken: "BBBBBBBBB",
81 TokenURL: url,
82 ClientID: "CLIENT_ID",
83 ClientSecret: "CLIENT_SECRET",
84 }
85 ts, err := config.TokenSource(context.Background())
86 if err != nil {
87 t.Fatalf("Error getting token source: %v", err)
88 }
89
90 token, err := ts.Token()
91 if err != nil {
92 t.Fatalf("Error retrieving Token: %v", err)
93 }
94 if got, want := token.AccessToken, "AAAAAAA"; got != want {
95 t.Fatalf("Unexpected access token, got %v, want %v", got, want)
96 }
97 if config.RefreshToken != "CCCCCCC" {
98 t.Fatalf("Refresh token not updated")
99 }
100 }
101
102 func TestExernalAccountAuthorizedUser_MinimumFieldsRequiredForRefresh(t *testing.T) {
103 server := &testRefreshTokenServer{
104 URL: "/",
105 Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
106 ContentType: "application/x-www-form-urlencoded",
107 Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
108 ResponsePayload: &stsexchange.Response{
109 ExpiresIn: 3600,
110 AccessToken: "AAAAAAA",
111 },
112 }
113
114 url, err := server.run(t)
115 if err != nil {
116 t.Fatalf("Error starting server")
117 }
118 defer server.close(t)
119
120 config := &Config{
121 RefreshToken: "BBBBBBBBB",
122 TokenURL: url,
123 ClientID: "CLIENT_ID",
124 ClientSecret: "CLIENT_SECRET",
125 }
126 ts, err := config.TokenSource(context.Background())
127 if err != nil {
128 t.Fatalf("Error getting token source: %v", err)
129 }
130
131 token, err := ts.Token()
132 if err != nil {
133 t.Fatalf("Error retrieving Token: %v", err)
134 }
135 if got, want := token.AccessToken, "AAAAAAA"; got != want {
136 t.Fatalf("Unexpected access token, got %v, want %v", got, want)
137 }
138 }
139
140 func TestExternalAccountAuthorizedUser_MissingRefreshFields(t *testing.T) {
141 server := &testRefreshTokenServer{
142 URL: "/",
143 Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
144 ContentType: "application/x-www-form-urlencoded",
145 Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB",
146 ResponsePayload: &stsexchange.Response{
147 ExpiresIn: 3600,
148 AccessToken: "AAAAAAA",
149 },
150 }
151
152 url, err := server.run(t)
153 if err != nil {
154 t.Fatalf("Error starting server")
155 }
156 defer server.close(t)
157 testCases := []struct {
158 name string
159 config Config
160 }{
161 {
162 name: "empty config",
163 config: Config{},
164 },
165 {
166 name: "missing refresh token",
167 config: Config{
168 TokenURL: url,
169 ClientID: "CLIENT_ID",
170 ClientSecret: "CLIENT_SECRET",
171 },
172 },
173 {
174 name: "missing token url",
175 config: Config{
176 RefreshToken: "BBBBBBBBB",
177 ClientID: "CLIENT_ID",
178 ClientSecret: "CLIENT_SECRET",
179 },
180 },
181 {
182 name: "missing client id",
183 config: Config{
184 RefreshToken: "BBBBBBBBB",
185 TokenURL: url,
186 ClientSecret: "CLIENT_SECRET",
187 },
188 },
189 {
190 name: "missing client secrect",
191 config: Config{
192 RefreshToken: "BBBBBBBBB",
193 TokenURL: url,
194 ClientID: "CLIENT_ID",
195 },
196 },
197 }
198 for _, tc := range testCases {
199 t.Run(tc.name, func(t *testing.T) {
200
201 expectErrMsg := "oauth2/google: Token should be created with fields to make it valid (`token` and `expiry`), or fields to allow it to refresh (`refresh_token`, `token_url`, `client_id`, `client_secret`)."
202 _, err := tc.config.TokenSource((context.Background()))
203 if err == nil {
204 t.Fatalf("Expected error, but received none")
205 }
206 if got := err.Error(); got != expectErrMsg {
207 t.Fatalf("Unexpected error, got %v, want %v", got, expectErrMsg)
208 }
209 })
210 }
211 }
212
213 func (trts *testRefreshTokenServer) run(t *testing.T) (string, error) {
214 t.Helper()
215 if trts.server != nil {
216 return "", errors.New("Server is already running")
217 }
218 trts.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
219 if got, want := r.URL.String(), trts.URL; got != want {
220 t.Errorf("URL.String(): got %v but want %v", got, want)
221 }
222 headerAuth := r.Header.Get("Authorization")
223 if got, want := headerAuth, trts.Authorization; got != want {
224 t.Errorf("got %v but want %v", got, want)
225 }
226 headerContentType := r.Header.Get("Content-Type")
227 if got, want := headerContentType, trts.ContentType; got != want {
228 t.Errorf("got %v but want %v", got, want)
229 }
230 body, err := ioutil.ReadAll(r.Body)
231 if err != nil {
232 t.Fatalf("Failed reading request body: %s.", err)
233 }
234 if got, want := string(body), trts.Body; got != want {
235 t.Errorf("Unexpected exchange payload: got %v but want %v", got, want)
236 }
237 w.Header().Set("Content-Type", "application/json")
238 if trts.ResponsePayload != nil {
239 content, err := json.Marshal(trts.ResponsePayload)
240 if err != nil {
241 t.Fatalf("unable to marshall response JSON")
242 }
243 w.Write(content)
244 } else {
245 w.Write([]byte(trts.Response))
246 }
247 }))
248 return trts.server.URL, nil
249 }
250
251 func (trts *testRefreshTokenServer) close(t *testing.T) error {
252 t.Helper()
253 if trts.server == nil {
254 return errors.New("No server is running")
255 }
256 trts.server.Close()
257 trts.server = nil
258 return nil
259 }
260
View as plain text