1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package auth
16
17 import (
18 "context"
19 "encoding/base64"
20 "encoding/json"
21 "fmt"
22 "net/http"
23 "net/http/httptest"
24 "strings"
25 "testing"
26 "time"
27
28 "cloud.google.com/go/auth/internal/jwt"
29 "github.com/google/go-cmp/cmp"
30 )
31
// fakePrivateKey is a PEM-encoded RSA private key used as the PrivateKey of
// the Options2LO values in these tests, so the 2LO provider can sign its JWT
// assertions against the local httptest servers. As the name says, it is a
// throwaway fixture and must never be used as a real credential.
var fakePrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAx4fm7dngEmOULNmAs1IGZ9Apfzh+BkaQ1dzkmbUgpcoghucE
DZRnAGd2aPyB6skGMXUytWQvNYav0WTR00wFtX1ohWTfv68HGXJ8QXCpyoSKSSFY
fuP9X36wBSkSX9J5DVgiuzD5VBdzUISSmapjKm+DcbRALjz6OUIPEWi1Tjl6p5RK
1w41qdbmt7E5/kGhKLDuT7+M83g4VWhgIvaAXtnhklDAggilPPa8ZJ1IFe31lNlr
k4DRk38nc6sEutdf3RL7QoH7FBusI7uXV03DC6dwN1kP4GE7bjJhcRb/7jYt7CQ9
/E9Exz3c0yAp0yrTg0Fwh+qxfH9dKwN52S7SBwIDAQABAoIBAQCaCs26K07WY5Jt
3a2Cw3y2gPrIgTCqX6hJs7O5ByEhXZ8nBwsWANBUe4vrGaajQHdLj5OKfsIDrOvn
2NI1MqflqeAbu/kR32q3tq8/Rl+PPiwUsW3E6Pcf1orGMSNCXxeducF2iySySzh3
nSIhCG5uwJDWI7a4+9KiieFgK1pt/Iv30q1SQS8IEntTfXYwANQrfKUVMmVF9aIK
6/WZE2yd5+q3wVVIJ6jsmTzoDCX6QQkkJICIYwCkglmVy5AeTckOVwcXL0jqw5Kf
5/soZJQwLEyBoQq7Kbpa26QHq+CJONetPP8Ssy8MJJXBT+u/bSseMb3Zsr5cr43e
DJOhwsThAoGBAPY6rPKl2NT/K7XfRCGm1sbWjUQyDShscwuWJ5+kD0yudnT/ZEJ1
M3+KS/iOOAoHDdEDi9crRvMl0UfNa8MAcDKHflzxg2jg/QI+fTBjPP5GOX0lkZ9g
z6VePoVoQw2gpPFVNPPTxKfk27tEzbaffvOLGBEih0Kb7HTINkW8rIlzAoGBAM9y
1yr+jvfS1cGFtNU+Gotoihw2eMKtIqR03Yn3n0PK1nVCDKqwdUqCypz4+ml6cxRK
J8+Pfdh7D+ZJd4LEG6Y4QRDLuv5OA700tUoSHxMSNn3q9As4+T3MUyYxWKvTeu3U
f2NWP9ePU0lV8ttk7YlpVRaPQmc1qwooBA/z/8AdAoGAW9x0HWqmRICWTBnpjyxx
QGlW9rQ9mHEtUotIaRSJ6K/F3cxSGUEkX1a3FRnp6kPLcckC6NlqdNgNBd6rb2rA
cPl/uSkZP42Als+9YMoFPU/xrrDPbUhu72EDrj3Bllnyb168jKLa4VBOccUvggxr
Dm08I1hgYgdN5huzs7y6GeUCgYEAj+AZJSOJ6o1aXS6rfV3mMRve9bQ9yt8jcKXw
5HhOCEmMtaSKfnOF1Ziih34Sxsb7O2428DiX0mV/YHtBnPsAJidL0SdLWIapBzeg
KHArByIRkwE6IvJvwpGMdaex1PIGhx5i/3VZL9qiq/ElT05PhIb+UXgoWMabCp84
OgxDK20CgYAeaFo8BdQ7FmVX2+EEejF+8xSge6WVLtkaon8bqcn6P0O8lLypoOhd
mJAYH8WU+UAy9pecUnDZj14LAGNVmYcse8HFX71MoshnvCTFEPVo4rZxIAGwMpeJ
5jgQ3slYLpqrGlcbLgUXBUgzEO684Wk/UV9DFPlHALVqCfXQ9dpJPg==
-----END RSA PRIVATE KEY-----`)
59
60 func TestError_Temporary(t *testing.T) {
61 tests := []struct {
62 name string
63 code int
64 want bool
65 }{
66 {
67 name: "temporary with 500",
68 code: http.StatusInternalServerError,
69 want: true,
70 },
71 {
72 name: "temporary with 503",
73 code: http.StatusServiceUnavailable,
74 want: true,
75 },
76 {
77 name: "temporary with 408",
78 code: http.StatusRequestTimeout,
79 want: true,
80 },
81 {
82 name: "temporary with 429",
83 code: http.StatusTooManyRequests,
84 want: true,
85 },
86 {
87 name: "temporary with 418",
88 code: http.StatusTeapot,
89 want: false,
90 },
91 }
92 for _, tt := range tests {
93 t.Run(tt.name, func(t *testing.T) {
94 ae := &Error{
95 Response: &http.Response{
96 StatusCode: tt.code,
97 },
98 }
99 if got := ae.Temporary(); got != tt.want {
100 t.Errorf("Temporary() = %v; want %v", got, tt.want)
101 }
102 })
103 }
104 }
105
106 func TestToken_isValidWithEarlyExpiry(t *testing.T) {
107 now := time.Now()
108 timeNow = func() time.Time { return now }
109 defer func() { timeNow = time.Now }()
110
111 cases := []struct {
112 name string
113 tok *Token
114 expiry time.Duration
115 want bool
116 }{
117 {name: "4 minutes", tok: &Token{Expiry: now.Add(4 * 60 * time.Second)}, expiry: defaultExpiryDelta, want: true},
118 {name: "3 minutes and 45 seconds", tok: &Token{Expiry: now.Add(defaultExpiryDelta)}, expiry: defaultExpiryDelta, want: true},
119 {name: "3 minutes and 45 seconds-1ns", tok: &Token{Expiry: now.Add(defaultExpiryDelta - 1*time.Nanosecond)}, expiry: defaultExpiryDelta, want: false},
120 {name: "-1 hour", tok: &Token{Expiry: now.Add(-1 * time.Hour)}, expiry: defaultExpiryDelta, want: false},
121 {name: "12 seconds, custom expiryDelta", tok: &Token{Expiry: now.Add(12 * time.Second)}, expiry: time.Second * 5, want: true},
122 {name: "5 seconds, custom expiryDelta", tok: &Token{Expiry: now.Add(time.Second * 5)}, expiry: time.Second * 5, want: true},
123 {name: "5 seconds-1ns, custom expiryDelta", tok: &Token{Expiry: now.Add(time.Second*5 - 1*time.Nanosecond)}, expiry: time.Second * 5, want: false},
124 {name: "-1 hour, custom expiryDelta", tok: &Token{Expiry: now.Add(-1 * time.Hour)}, expiry: time.Second * 5, want: false},
125 }
126 for _, tc := range cases {
127 tc.tok.Value = "tok"
128 if got, want := tc.tok.isValidWithEarlyExpiry(tc.expiry), tc.want; got != want {
129 t.Errorf("expired (%q) = %v; want %v", tc.name, got, want)
130 }
131 }
132 }
133
134 func TestError_Error(t *testing.T) {
135
136 tests := []struct {
137 name string
138
139 Response *http.Response
140 Body []byte
141 Err error
142 code string
143 description string
144 uri string
145
146 want string
147 }{
148 {
149 name: "basic",
150 Response: &http.Response{
151 StatusCode: http.StatusTeapot,
152 },
153 Body: []byte("I'm a teapot"),
154 want: "auth: cannot fetch token: 418\nResponse: I'm a teapot",
155 },
156 {
157 name: "from query",
158 code: fmt.Sprint(http.StatusTeapot),
159 description: "I'm a teapot",
160 uri: "somewhere",
161 want: "auth: \"418\" \"I'm a teapot\" \"somewhere\"",
162 },
163 }
164 for _, tt := range tests {
165 t.Run(tt.name, func(t *testing.T) {
166 r := &Error{
167 Response: tt.Response,
168 Body: tt.Body,
169 Err: tt.Err,
170 code: tt.code,
171 description: tt.description,
172 uri: tt.uri,
173 }
174 if got := r.Error(); got != tt.want {
175 t.Errorf("Error.Error() = %v, want %v", got, tt.want)
176 }
177 })
178 }
179 }
180
181 func TestNew2LOTokenProvider_JSONResponse(t *testing.T) {
182 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
183 w.Header().Set("Content-Type", "application/json")
184 w.Write([]byte(`{
185 "access_token": "90d64460d14870c08c81352a05dedd3465940a7c",
186 "scope": "user",
187 "token_type": "bearer",
188 "expires_in": 3600
189 }`))
190 }))
191 defer ts.Close()
192
193 opts := &Options2LO{
194 Email: "aaa@example.com",
195 PrivateKey: fakePrivateKey,
196 TokenURL: ts.URL,
197 }
198 tp, err := New2LOTokenProvider(opts)
199 if err != nil {
200 t.Fatal(err)
201 }
202 tok, err := tp.Token(context.Background())
203 if err != nil {
204 t.Fatal(err)
205 }
206 if !tok.IsValid() {
207 t.Errorf("got invalid token: %v", tok)
208 }
209 if got, want := tok.Value, "90d64460d14870c08c81352a05dedd3465940a7c"; got != want {
210 t.Errorf("access token = %q; want %q", got, want)
211 }
212 if got, want := tok.Type, "bearer"; got != want {
213 t.Errorf("token type = %q; want %q", got, want)
214 }
215 if got := tok.Expiry.IsZero(); got {
216 t.Errorf("token expiry = %v, want none", got)
217 }
218 scope := tok.Metadata["scope"].(string)
219 if got, want := scope, "user"; got != want {
220 t.Errorf("scope = %q; want %q", got, want)
221 }
222 }
223
224 func TestNew2LOTokenProvider_BadResponse(t *testing.T) {
225 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
226 w.Header().Set("Content-Type", "application/json")
227 w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`))
228 }))
229 defer ts.Close()
230
231 opts := &Options2LO{
232 Email: "aaa@example.com",
233 PrivateKey: fakePrivateKey,
234 TokenURL: ts.URL,
235 }
236 tp, err := New2LOTokenProvider(opts)
237 if err != nil {
238 t.Fatal(err)
239 }
240 tok, err := tp.Token(context.Background())
241 if err != nil {
242 t.Fatal(err)
243 }
244 if tok == nil {
245 t.Fatalf("got nil token; want token")
246 }
247 if tok.IsValid() {
248 t.Errorf("got invalid token: %v", tok)
249 }
250 if got, want := tok.Value, ""; got != want {
251 t.Errorf("access token = %q; want %q", got, want)
252 }
253 if got, want := tok.Type, "bearer"; got != want {
254 t.Errorf("token type = %q; want %q", got, want)
255 }
256 scope := tok.Metadata["scope"].(string)
257 if got, want := scope, "user"; got != want {
258 t.Errorf("token scope = %q; want %q", got, want)
259 }
260 }
261
262 func TestNew2LOTokenProvider_BadResponseType(t *testing.T) {
263 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
264 w.Header().Set("Content-Type", "application/json")
265 w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`))
266 }))
267 defer ts.Close()
268 opts := &Options2LO{
269 Email: "aaa@example.com",
270 PrivateKey: fakePrivateKey,
271 TokenURL: ts.URL,
272 }
273 tp, err := New2LOTokenProvider(opts)
274 if err != nil {
275 t.Fatal(err)
276 }
277 tok, err := tp.Token(context.Background())
278 if err == nil {
279 t.Error("got a token; expected error")
280 if got, want := tok.Value, ""; got != want {
281 t.Errorf("access token = %q; want %q", got, want)
282 }
283 }
284 }
285
286 func TestNew2LOTokenProvider_Assertion(t *testing.T) {
287 var assertion string
288 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
289 r.ParseForm()
290 assertion = r.Form.Get("assertion")
291
292 w.Header().Set("Content-Type", "application/json")
293 w.Write([]byte(`{
294 "access_token": "90d64460d14870c08c81352a05dedd3465940a7c",
295 "scope": "user",
296 "token_type": "bearer",
297 "expires_in": 3600
298 }`))
299 }))
300 defer ts.Close()
301
302 opts := &Options2LO{
303 Email: "aaa@example.com",
304 PrivateKey: fakePrivateKey,
305 PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
306 TokenURL: ts.URL,
307 }
308
309 tp, err := New2LOTokenProvider(opts)
310 if err != nil {
311 t.Fatal(err)
312 }
313 _, err = tp.Token(context.Background())
314 if err != nil {
315 t.Fatalf("Failed to fetch token: %v", err)
316 }
317
318 parts := strings.Split(assertion, ".")
319 if len(parts) != 3 {
320 t.Fatalf("assertion = %q; want 3 parts", assertion)
321 }
322 gotjson, err := base64.RawURLEncoding.DecodeString(parts[0])
323 if err != nil {
324 t.Fatalf("invalid token header; err = %v", err)
325 }
326
327 got := jwt.Header{}
328 if err := json.Unmarshal(gotjson, &got); err != nil {
329 t.Errorf("failed to unmarshal json token header = %q; err = %v", gotjson, err)
330 }
331
332 want := jwt.Header{
333 Algorithm: "RS256",
334 Type: "JWT",
335 KeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
336 }
337 if got != want {
338 t.Errorf("access token header = %q; want %q", got, want)
339 }
340 }
341
342 func TestNew2LOTokenProvider_AssertionPayload(t *testing.T) {
343 var assertion string
344 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
345 r.ParseForm()
346 assertion = r.Form.Get("assertion")
347
348 w.Header().Set("Content-Type", "application/json")
349 w.Write([]byte(`{
350 "access_token": "90d64460d14870c08c81352a05dedd3465940a7c",
351 "scope": "user",
352 "token_type": "bearer",
353 "expires_in": 3600
354 }`))
355 }))
356 defer ts.Close()
357
358 for _, opts := range []*Options2LO{
359 {
360 Email: "aaa1@example.com",
361 PrivateKey: fakePrivateKey,
362 PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
363 TokenURL: ts.URL,
364 },
365 {
366 Email: "aaa2@example.com",
367 PrivateKey: fakePrivateKey,
368 PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
369 TokenURL: ts.URL,
370 Audience: "https://example.com",
371 },
372 {
373 Email: "aaa2@example.com",
374 PrivateKey: fakePrivateKey,
375 PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
376 TokenURL: ts.URL,
377 PrivateClaims: map[string]interface{}{
378 "private0": "claim0",
379 "private1": "claim1",
380 },
381 },
382 } {
383 t.Run(opts.Email, func(t *testing.T) {
384 tp, err := New2LOTokenProvider(opts)
385 if err != nil {
386 t.Fatal(err)
387 }
388 _, err = tp.Token(context.Background())
389 if err != nil {
390 t.Fatalf("Failed to fetch token: %v", err)
391 }
392
393 parts := strings.Split(assertion, ".")
394 if len(parts) != 3 {
395 t.Fatalf("assertion = %q; want 3 parts", assertion)
396 }
397 gotjson, err := base64.RawURLEncoding.DecodeString(parts[1])
398 if err != nil {
399 t.Fatalf("invalid token payload; err = %v", err)
400 }
401
402 claimSet := jwt.Claims{}
403 if err := json.Unmarshal(gotjson, &claimSet); err != nil {
404 t.Errorf("failed to unmarshal json token payload = %q; err = %v", gotjson, err)
405 }
406
407 if got, want := claimSet.Iss, opts.Email; got != want {
408 t.Errorf("payload email = %q; want %q", got, want)
409 }
410 if got, want := claimSet.Scope, strings.Join(opts.Scopes, " "); got != want {
411 t.Errorf("payload scope = %q; want %q", got, want)
412 }
413 aud := opts.TokenURL
414 if opts.Audience != "" {
415 aud = opts.Audience
416 }
417 if got, want := claimSet.Aud, aud; got != want {
418 t.Errorf("payload audience = %q; want %q", got, want)
419 }
420 if got, want := claimSet.Sub, opts.Subject; got != want {
421 t.Errorf("payload subject = %q; want %q", got, want)
422 }
423 if len(opts.PrivateClaims) > 0 {
424 var got interface{}
425 if err := json.Unmarshal(gotjson, &got); err != nil {
426 t.Errorf("failed to parse payload; err = %q", err)
427 }
428 m := got.(map[string]interface{})
429 for v, k := range opts.PrivateClaims {
430 if !cmp.Equal(m[v], k) {
431 t.Errorf("payload private claims key = %q: got %#v; want %#v", v, m[v], k)
432 }
433 }
434 }
435 })
436 }
437 }
438
439 func TestNew2LOTokenProvider_TokenError(t *testing.T) {
440 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
441 w.Header().Set("Content-type", "application/json")
442 w.WriteHeader(http.StatusBadRequest)
443 w.Write([]byte(`{"error": "invalid_grant"}`))
444 }))
445 defer ts.Close()
446
447 opts := &Options2LO{
448 Email: "aaa@example.com",
449 PrivateKey: fakePrivateKey,
450 TokenURL: ts.URL,
451 }
452
453 tp, err := New2LOTokenProvider(opts)
454 if err != nil {
455 t.Fatal(err)
456 }
457 _, err = tp.Token(context.Background())
458 if err == nil {
459 t.Fatalf("got no error, expected one")
460 }
461 _, ok := err.(*Error)
462 if !ok {
463 t.Fatalf("got %T error, expected *Error", err)
464 }
465 expected := fmt.Sprintf("auth: cannot fetch token: %v\nResponse: %s", "400", `{"error": "invalid_grant"}`)
466 if errStr := err.Error(); errStr != expected {
467 t.Fatalf("got %#v, expected %#v", errStr, expected)
468 }
469 }
470
471 func TestNew2LOTokenProvider_Validate(t *testing.T) {
472 tests := []struct {
473 name string
474 opts *Options2LO
475 }{
476 {
477 name: "missing options",
478 },
479 {
480 name: "missing email",
481 opts: &Options2LO{
482 PrivateKey: []byte("key"),
483 TokenURL: "url",
484 },
485 },
486 {
487 name: "missing key",
488 opts: &Options2LO{
489 Email: "email",
490 TokenURL: "url",
491 },
492 },
493 {
494 name: "missing URL",
495 opts: &Options2LO{
496 Email: "email",
497 PrivateKey: []byte("key"),
498 },
499 },
500 }
501 for _, tt := range tests {
502 t.Run(tt.name, func(t *testing.T) {
503 if _, err := New2LOTokenProvider(tt.opts); err == nil {
504 t.Error("got nil, want an error")
505 }
506 })
507 }
508 }
509
View as plain text