1
18
19 package sts
20
21 import (
22 "bytes"
23 "context"
24 "crypto/x509"
25 "encoding/json"
26 "errors"
27 "fmt"
28 "io"
29 "net/http"
30 "net/http/httputil"
31 "strings"
32 "testing"
33 "time"
34
35 "github.com/google/go-cmp/cmp"
36
37 "google.golang.org/grpc/credentials"
38 icredentials "google.golang.org/grpc/internal/credentials"
39 "google.golang.org/grpc/internal/grpctest"
40 "google.golang.org/grpc/internal/testutils"
41 )
42
// Constants shared by the tests in this file.
const (
	// requestedTokenType is the RFC 8693 token type identifier for an OAuth 2.0
	// access token. RFC 8693 spells it with an underscore ("access_token").
	requestedTokenType = "urn:ietf:params:oauth:token-type:access_token"
	// actorTokenPath is never read from disk: tests stub readActorTokenFrom.
	actorTokenPath     = "/var/run/secrets/token.jwt"
	actorTokenType     = "urn:ietf:params:oauth:token-type:refresh_token"
	actorTokenContents = "actorToken.jwt.contents"
	// accessTokenContents is the token the fake STS server hands out.
	accessTokenContents = "access_token"
	// subjectTokenPath is never read from disk: tests stub readSubjectTokenFrom.
	subjectTokenPath     = "/var/run/secrets/token.jwt"
	subjectTokenType     = "urn:ietf:params:oauth:token-type:id_token"
	subjectTokenContents = "subjectToken.jwt.contents"
	serviceURI           = "http://localhost"
	exampleResource      = "https://backend.example.com/api"
	exampleAudience      = "example-backend-service"
	testScope            = "https://www.googleapis.com/auth/monitoring"
	// defaultTestTimeout bounds operations that are expected to complete.
	defaultTestTimeout = 1 * time.Second
	// defaultTestShortTimeout bounds waits for events that are expected NOT
	// to happen.
	defaultTestShortTimeout = 10 * time.Millisecond
)
59
60 var (
61 goodOptions = Options{
62 TokenExchangeServiceURI: serviceURI,
63 Audience: exampleAudience,
64 RequestedTokenType: requestedTokenType,
65 SubjectTokenPath: subjectTokenPath,
66 SubjectTokenType: subjectTokenType,
67 }
68 goodRequestParams = &requestParameters{
69 GrantType: tokenExchangeGrantType,
70 Audience: exampleAudience,
71 Scope: defaultCloudPlatformScope,
72 RequestedTokenType: requestedTokenType,
73 SubjectToken: subjectTokenContents,
74 SubjectTokenType: subjectTokenType,
75 }
76 goodMetadata = map[string]string{
77 "Authorization": fmt.Sprintf("Bearer %s", accessTokenContents),
78 }
79 )
80
81 type s struct {
82 grpctest.Tester
83 }
84
85 func Test(t *testing.T) {
86 grpctest.RunSubTests(t, s{})
87 }
88
89
90
91 type testAuthInfo struct {
92 credentials.CommonAuthInfo
93 }
94
95 func (ta testAuthInfo) AuthType() string {
96 return "testAuthInfo"
97 }
98
99 func createTestContext(ctx context.Context, s credentials.SecurityLevel) context.Context {
100 auth := &testAuthInfo{CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: s}}
101 ri := credentials.RequestInfo{
102 Method: "testInfo",
103 AuthInfo: auth,
104 }
105 return icredentials.NewRequestInfoContext(ctx, ri)
106 }
107
108
109
// errReader is an io.Reader whose Read always fails. It simulates a response
// body that cannot be read.
type errReader struct{}

// Read reports that no bytes were read, together with a "read error" failure.
func (errReader) Read([]byte) (int, error) {
	return 0, errors.New("read error")
}
115
116
117
118
119 func makeGoodResponse() *http.Response {
120 respJSON, _ := json.Marshal(responseParameters{
121 AccessToken: accessTokenContents,
122 IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
123 TokenType: "Bearer",
124 ExpiresIn: 3600,
125 })
126 respBody := io.NopCloser(bytes.NewReader(respJSON))
127 return &http.Response{
128 Status: "200 OK",
129 StatusCode: http.StatusOK,
130 Body: respBody,
131 }
132 }
133
134
135 func overrideHTTPClientGood() (*testutils.FakeHTTPClient, func()) {
136 fc := &testutils.FakeHTTPClient{
137 ReqChan: testutils.NewChannel(),
138 RespChan: testutils.NewChannel(),
139 }
140 fc.RespChan.Send(makeGoodResponse())
141
142 origMakeHTTPDoer := makeHTTPDoer
143 makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc }
144 return fc, func() { makeHTTPDoer = origMakeHTTPDoer }
145 }
146
147
148 func overrideHTTPClient(fc *testutils.FakeHTTPClient) func() {
149 origMakeHTTPDoer := makeHTTPDoer
150 makeHTTPDoer = func(_ *x509.CertPool) httpDoer { return fc }
151 return func() { makeHTTPDoer = origMakeHTTPDoer }
152 }
153
154
155
156 func overrideSubjectTokenGood() func() {
157 origReadSubjectTokenFrom := readSubjectTokenFrom
158 readSubjectTokenFrom = func(path string) ([]byte, error) {
159 return []byte(subjectTokenContents), nil
160 }
161 return func() { readSubjectTokenFrom = origReadSubjectTokenFrom }
162 }
163
164
165 func overrideSubjectTokenError() func() {
166 origReadSubjectTokenFrom := readSubjectTokenFrom
167 readSubjectTokenFrom = func(path string) ([]byte, error) {
168 return nil, errors.New("error reading subject token")
169 }
170 return func() { readSubjectTokenFrom = origReadSubjectTokenFrom }
171 }
172
173
174
175 func overrideActorTokenGood() func() {
176 origReadActorTokenFrom := readActorTokenFrom
177 readActorTokenFrom = func(path string) ([]byte, error) {
178 return []byte(actorTokenContents), nil
179 }
180 return func() { readActorTokenFrom = origReadActorTokenFrom }
181 }
182
183
184 func overrideActorTokenError() func() {
185 origReadActorTokenFrom := readActorTokenFrom
186 readActorTokenFrom = func(path string) ([]byte, error) {
187 return nil, errors.New("error reading actor token")
188 }
189 return func() { readActorTokenFrom = origReadActorTokenFrom }
190 }
191
192
193
194 func compareRequest(gotRequest *http.Request, wantReqParams *requestParameters) error {
195 jsonBody, err := json.Marshal(wantReqParams)
196 if err != nil {
197 return err
198 }
199 wantReq, err := http.NewRequest("POST", serviceURI, bytes.NewBuffer(jsonBody))
200 if err != nil {
201 return fmt.Errorf("failed to create http request: %v", err)
202 }
203 wantReq.Header.Set("Content-Type", "application/json")
204
205 wantR, err := httputil.DumpRequestOut(wantReq, true)
206 if err != nil {
207 return err
208 }
209 gotR, err := httputil.DumpRequestOut(gotRequest, true)
210 if err != nil {
211 return err
212 }
213 if diff := cmp.Diff(string(wantR), string(gotR)); diff != "" {
214 return fmt.Errorf("sts request diff (-want +got):\n%s", diff)
215 }
216 return nil
217 }
218
219
220
221
222
223
224 func receiveAndCompareRequest(ReqChan *testutils.Channel, errCh chan error) {
225 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
226 defer cancel()
227
228 val, err := ReqChan.Receive(ctx)
229 if err != nil {
230 errCh <- err
231 return
232 }
233 req := val.(*http.Request)
234 if err := compareRequest(req, goodRequestParams); err != nil {
235 errCh <- err
236 return
237 }
238 errCh <- nil
239 }
240
241
242
243 func (s) TestGetRequestMetadataSuccess(t *testing.T) {
244 defer overrideSubjectTokenGood()()
245 fc, cancel := overrideHTTPClientGood()
246 defer cancel()
247
248 creds, err := NewCredentials(goodOptions)
249 if err != nil {
250 t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
251 }
252
253 errCh := make(chan error, 1)
254 go receiveAndCompareRequest(fc.ReqChan, errCh)
255
256 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
257 defer cancel()
258
259 gotMetadata, err := creds.GetRequestMetadata(createTestContext(ctx, credentials.PrivacyAndIntegrity), "")
260 if err != nil {
261 t.Fatalf("creds.GetRequestMetadata() = %v", err)
262 }
263 if !cmp.Equal(gotMetadata, goodMetadata) {
264 t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata)
265 }
266 if err := <-errCh; err != nil {
267 t.Fatal(err)
268 }
269
270
271
272
273
274 gotMetadata, err = creds.GetRequestMetadata(createTestContext(ctx, credentials.PrivacyAndIntegrity), "")
275 if err != nil {
276 t.Fatalf("creds.GetRequestMetadata() = %v", err)
277 }
278 if !cmp.Equal(gotMetadata, goodMetadata) {
279 t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata)
280 }
281 }
282
283
284
285
286 func (s) TestGetRequestMetadataBadSecurityLevel(t *testing.T) {
287 defer overrideSubjectTokenGood()()
288
289 creds, err := NewCredentials(goodOptions)
290 if err != nil {
291 t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
292 }
293
294 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
295 defer cancel()
296 gotMetadata, err := creds.GetRequestMetadata(createTestContext(ctx, credentials.IntegrityOnly), "")
297 if err == nil {
298 t.Fatalf("creds.GetRequestMetadata() succeeded with metadata %v, expected to fail", gotMetadata)
299 }
300 }
301
302
303
304
305 func (s) TestGetRequestMetadataCacheExpiry(t *testing.T) {
306 const expiresInSecs = 1
307 defer overrideSubjectTokenGood()()
308 fc := &testutils.FakeHTTPClient{
309 ReqChan: testutils.NewChannel(),
310 RespChan: testutils.NewChannel(),
311 }
312 defer overrideHTTPClient(fc)()
313
314 creds, err := NewCredentials(goodOptions)
315 if err != nil {
316 t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
317 }
318
319
320
321
322
323 for i := 0; i < 2; i++ {
324 errCh := make(chan error, 1)
325 go receiveAndCompareRequest(fc.ReqChan, errCh)
326
327 respJSON, _ := json.Marshal(responseParameters{
328 AccessToken: accessTokenContents,
329 IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
330 TokenType: "Bearer",
331 ExpiresIn: expiresInSecs,
332 })
333 respBody := io.NopCloser(bytes.NewReader(respJSON))
334 resp := &http.Response{
335 Status: "200 OK",
336 StatusCode: http.StatusOK,
337 Body: respBody,
338 }
339 fc.RespChan.Send(resp)
340
341 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
342 defer cancel()
343 gotMetadata, err := creds.GetRequestMetadata(createTestContext(ctx, credentials.PrivacyAndIntegrity), "")
344 if err != nil {
345 t.Fatalf("creds.GetRequestMetadata() = %v", err)
346 }
347 if !cmp.Equal(gotMetadata, goodMetadata) {
348 t.Fatalf("creds.GetRequestMetadata() = %v, want %v", gotMetadata, goodMetadata)
349 }
350 if err := <-errCh; err != nil {
351 t.Fatal(err)
352 }
353 time.Sleep(expiresInSecs * time.Second)
354 }
355 }
356
357
358
359 func (s) TestGetRequestMetadataBadResponses(t *testing.T) {
360 tests := []struct {
361 name string
362 response *http.Response
363 }{
364 {
365 name: "bad JSON",
366 response: &http.Response{
367 Status: "200 OK",
368 StatusCode: http.StatusOK,
369 Body: io.NopCloser(strings.NewReader("not JSON")),
370 },
371 },
372 {
373 name: "no access token",
374 response: &http.Response{
375 Status: "200 OK",
376 StatusCode: http.StatusOK,
377 Body: io.NopCloser(strings.NewReader("{}")),
378 },
379 },
380 }
381
382 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
383 defer cancel()
384 for _, test := range tests {
385 t.Run(test.name, func(t *testing.T) {
386 defer overrideSubjectTokenGood()()
387
388 fc := &testutils.FakeHTTPClient{
389 ReqChan: testutils.NewChannel(),
390 RespChan: testutils.NewChannel(),
391 }
392 defer overrideHTTPClient(fc)()
393
394 creds, err := NewCredentials(goodOptions)
395 if err != nil {
396 t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
397 }
398
399 errCh := make(chan error, 1)
400 go receiveAndCompareRequest(fc.ReqChan, errCh)
401
402 fc.RespChan.Send(test.response)
403 if _, err := creds.GetRequestMetadata(createTestContext(ctx, credentials.PrivacyAndIntegrity), ""); err == nil {
404 t.Fatal("creds.GetRequestMetadata() succeeded when expected to fail")
405 }
406 if err := <-errCh; err != nil {
407 t.Fatal(err)
408 }
409 })
410 }
411 }
412
413
414
415 func (s) TestGetRequestMetadataBadSubjectTokenRead(t *testing.T) {
416 defer overrideSubjectTokenError()()
417 fc, cancel := overrideHTTPClientGood()
418 defer cancel()
419
420 creds, err := NewCredentials(goodOptions)
421 if err != nil {
422 t.Fatalf("NewCredentials(%v) = %v", goodOptions, err)
423 }
424
425 errCh := make(chan error, 1)
426 go func() {
427 ctx, cancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
428 defer cancel()
429 if _, err := fc.ReqChan.Receive(ctx); err != context.DeadlineExceeded {
430 errCh <- err
431 return
432 }
433 errCh <- nil
434 }()
435
436 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
437 defer cancel()
438 if _, err := creds.GetRequestMetadata(createTestContext(ctx, credentials.PrivacyAndIntegrity), ""); err == nil {
439 t.Fatal("creds.GetRequestMetadata() succeeded when expected to fail")
440 }
441 if err := <-errCh; err != nil {
442 t.Fatal(err)
443 }
444 }
445
446 func (s) TestNewCredentials(t *testing.T) {
447 tests := []struct {
448 name string
449 opts Options
450 errSystemRoots bool
451 wantErr bool
452 }{
453 {
454 name: "invalid options - empty subjectTokenPath",
455 opts: Options{
456 TokenExchangeServiceURI: serviceURI,
457 },
458 wantErr: true,
459 },
460 {
461 name: "invalid system root certs",
462 opts: goodOptions,
463 errSystemRoots: true,
464 wantErr: true,
465 },
466 {
467 name: "good case",
468 opts: goodOptions,
469 },
470 }
471
472 for _, test := range tests {
473 t.Run(test.name, func(t *testing.T) {
474 if test.errSystemRoots {
475 oldSystemRoots := loadSystemCertPool
476 loadSystemCertPool = func() (*x509.CertPool, error) {
477 return nil, errors.New("failed to load system cert pool")
478 }
479 defer func() {
480 loadSystemCertPool = oldSystemRoots
481 }()
482 }
483
484 creds, err := NewCredentials(test.opts)
485 if (err != nil) != test.wantErr {
486 t.Fatalf("NewCredentials(%v) = %v, want %v", test.opts, err, test.wantErr)
487 }
488 if err == nil {
489 if !creds.RequireTransportSecurity() {
490 t.Errorf("creds.RequireTransportSecurity() returned false")
491 }
492 }
493 })
494 }
495 }
496
497 func (s) TestValidateOptions(t *testing.T) {
498 tests := []struct {
499 name string
500 opts Options
501 wantErrPrefix string
502 }{
503 {
504 name: "empty token exchange service URI",
505 opts: Options{},
506 wantErrPrefix: "empty token_exchange_service_uri in options",
507 },
508 {
509 name: "invalid URI",
510 opts: Options{
511 TokenExchangeServiceURI: "\tI'm a bad URI\n",
512 },
513 wantErrPrefix: "invalid control character in URL",
514 },
515 {
516 name: "unsupported scheme",
517 opts: Options{
518 TokenExchangeServiceURI: "unix:///path/to/socket",
519 },
520 wantErrPrefix: "scheme is not supported",
521 },
522 {
523 name: "empty subjectTokenPath",
524 opts: Options{
525 TokenExchangeServiceURI: serviceURI,
526 },
527 wantErrPrefix: "required field SubjectTokenPath is not specified",
528 },
529 {
530 name: "empty subjectTokenType",
531 opts: Options{
532 TokenExchangeServiceURI: serviceURI,
533 SubjectTokenPath: subjectTokenPath,
534 },
535 wantErrPrefix: "required field SubjectTokenType is not specified",
536 },
537 {
538 name: "good options",
539 opts: goodOptions,
540 },
541 }
542
543 for _, test := range tests {
544 t.Run(test.name, func(t *testing.T) {
545 err := validateOptions(test.opts)
546 if (err != nil) != (test.wantErrPrefix != "") {
547 t.Errorf("validateOptions(%v) = %v, want %v", test.opts, err, test.wantErrPrefix)
548 }
549 if err != nil && !strings.Contains(err.Error(), test.wantErrPrefix) {
550 t.Errorf("validateOptions(%v) = %v, want %v", test.opts, err, test.wantErrPrefix)
551 }
552 })
553 }
554 }
555
556 func (s) TestConstructRequest(t *testing.T) {
557 tests := []struct {
558 name string
559 opts Options
560 subjectTokenReadErr bool
561 actorTokenReadErr bool
562 wantReqParams *requestParameters
563 wantErr bool
564 }{
565 {
566 name: "subject token read failure",
567 subjectTokenReadErr: true,
568 opts: goodOptions,
569 wantErr: true,
570 },
571 {
572 name: "actor token read failure",
573 actorTokenReadErr: true,
574 opts: Options{
575 TokenExchangeServiceURI: serviceURI,
576 Audience: exampleAudience,
577 RequestedTokenType: requestedTokenType,
578 SubjectTokenPath: subjectTokenPath,
579 SubjectTokenType: subjectTokenType,
580 ActorTokenPath: actorTokenPath,
581 ActorTokenType: actorTokenType,
582 },
583 wantErr: true,
584 },
585 {
586 name: "default cloud platform scope",
587 opts: goodOptions,
588 wantReqParams: goodRequestParams,
589 },
590 {
591 name: "all good",
592 opts: Options{
593 TokenExchangeServiceURI: serviceURI,
594 Resource: exampleResource,
595 Audience: exampleAudience,
596 Scope: testScope,
597 RequestedTokenType: requestedTokenType,
598 SubjectTokenPath: subjectTokenPath,
599 SubjectTokenType: subjectTokenType,
600 ActorTokenPath: actorTokenPath,
601 ActorTokenType: actorTokenType,
602 },
603 wantReqParams: &requestParameters{
604 GrantType: tokenExchangeGrantType,
605 Resource: exampleResource,
606 Audience: exampleAudience,
607 Scope: testScope,
608 RequestedTokenType: requestedTokenType,
609 SubjectToken: subjectTokenContents,
610 SubjectTokenType: subjectTokenType,
611 ActorToken: actorTokenContents,
612 ActorTokenType: actorTokenType,
613 },
614 },
615 }
616
617 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
618 defer cancel()
619 for _, test := range tests {
620 t.Run(test.name, func(t *testing.T) {
621 if test.subjectTokenReadErr {
622 defer overrideSubjectTokenError()()
623 } else {
624 defer overrideSubjectTokenGood()()
625 }
626
627 if test.actorTokenReadErr {
628 defer overrideActorTokenError()()
629 } else {
630 defer overrideActorTokenGood()()
631 }
632
633 gotRequest, err := constructRequest(ctx, test.opts)
634 if (err != nil) != test.wantErr {
635 t.Fatalf("constructRequest(%v) = %v, wantErr: %v", test.opts, err, test.wantErr)
636 }
637 if test.wantErr {
638 return
639 }
640 if err := compareRequest(gotRequest, test.wantReqParams); err != nil {
641 t.Fatal(err)
642 }
643 })
644 }
645 }
646
647 func (s) TestSendRequest(t *testing.T) {
648 defer overrideSubjectTokenGood()()
649 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
650 defer cancel()
651 req, err := constructRequest(ctx, goodOptions)
652 if err != nil {
653 t.Fatal(err)
654 }
655
656 tests := []struct {
657 name string
658 resp *http.Response
659 respErr error
660 wantErr bool
661 }{
662 {
663 name: "client error",
664 respErr: errors.New("http.Client.Do failed"),
665 wantErr: true,
666 },
667 {
668 name: "bad response body",
669 resp: &http.Response{
670 Status: "200 OK",
671 StatusCode: http.StatusOK,
672 Body: io.NopCloser(errReader{}),
673 },
674 wantErr: true,
675 },
676 {
677 name: "nonOK status code",
678 resp: &http.Response{
679 Status: "400 BadRequest",
680 StatusCode: http.StatusBadRequest,
681 Body: io.NopCloser(strings.NewReader("")),
682 },
683 wantErr: true,
684 },
685 {
686 name: "good case",
687 resp: makeGoodResponse(),
688 },
689 }
690
691 for _, test := range tests {
692 t.Run(test.name, func(t *testing.T) {
693 client := &testutils.FakeHTTPClient{
694 ReqChan: testutils.NewChannel(),
695 RespChan: testutils.NewChannel(),
696 Err: test.respErr,
697 }
698 client.RespChan.Send(test.resp)
699 _, err := sendRequest(client, req)
700 if (err != nil) != test.wantErr {
701 t.Errorf("sendRequest(%v) = %v, wantErr: %v", req, err, test.wantErr)
702 }
703 })
704 }
705 }
706
707 func (s) TestTokenInfoFromResponse(t *testing.T) {
708 noAccessToken, _ := json.Marshal(responseParameters{
709 IssuedTokenType: "urn:ietf:params:oauth:token-type:access_token",
710 TokenType: "Bearer",
711 ExpiresIn: 3600,
712 })
713 goodResponse, _ := json.Marshal(responseParameters{
714 IssuedTokenType: requestedTokenType,
715 AccessToken: accessTokenContents,
716 TokenType: "Bearer",
717 ExpiresIn: 3600,
718 })
719
720 tests := []struct {
721 name string
722 respBody []byte
723 wantTokenInfo *tokenInfo
724 wantErr bool
725 }{
726 {
727 name: "bad JSON",
728 respBody: []byte("not JSON"),
729 wantErr: true,
730 },
731 {
732 name: "empty response",
733 respBody: []byte(""),
734 wantErr: true,
735 },
736 {
737 name: "non-empty response with no access token",
738 respBody: noAccessToken,
739 wantErr: true,
740 },
741 {
742 name: "good response",
743 respBody: goodResponse,
744 wantTokenInfo: &tokenInfo{
745 tokenType: "Bearer",
746 token: accessTokenContents,
747 },
748 },
749 }
750
751 for _, test := range tests {
752 t.Run(test.name, func(t *testing.T) {
753 gotTokenInfo, err := tokenInfoFromResponse(test.respBody)
754 if (err != nil) != test.wantErr {
755 t.Fatalf("tokenInfoFromResponse(%+v) = %v, wantErr: %v", test.respBody, err, test.wantErr)
756 }
757 if test.wantErr {
758 return
759 }
760
761
762 if gotTokenInfo.tokenType != test.wantTokenInfo.tokenType || gotTokenInfo.token != test.wantTokenInfo.token {
763 t.Errorf("tokenInfoFromResponse(%+v) = %+v, want: %+v", test.respBody, gotTokenInfo, test.wantTokenInfo)
764 }
765 })
766 }
767 }
768
View as plain text