1
2
3
4
5 package impersonate
6
import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"
	"testing"
	"time"

	"google.golang.org/api/option"
)
19
20 func TestTokenSource_serviceAccount(t *testing.T) {
21 ctx := context.Background()
22 tests := []struct {
23 name string
24 config CredentialsConfig
25 opts option.ClientOption
26 wantErr error
27 }{
28 {
29 name: "missing targetPrincipal",
30 wantErr: errMissingTargetPrincipal,
31 },
32 {
33 name: "missing scopes",
34 config: CredentialsConfig{
35 TargetPrincipal: "foo@project-id.iam.gserviceaccount.com",
36 },
37 wantErr: errMissingScopes,
38 },
39 {
40 name: "lifetime over max",
41 config: CredentialsConfig{
42 TargetPrincipal: "foo@project-id.iam.gserviceaccount.com",
43 Scopes: []string{"scope"},
44 Lifetime: 13 * time.Hour,
45 },
46 wantErr: errLifetimeOverMax,
47 },
48 {
49 name: "works",
50 config: CredentialsConfig{
51 TargetPrincipal: "foo@project-id.iam.gserviceaccount.com",
52 Scopes: []string{"scope"},
53 },
54 wantErr: nil,
55 },
56 {
57 name: "universe domain",
58 config: CredentialsConfig{
59 TargetPrincipal: "foo@project-id.iam.gserviceaccount.com",
60 Scopes: []string{"scope"},
61 Subject: "admin@example.com",
62 },
63 opts: option.WithUniverseDomain("example.com"),
64 wantErr: errUniverseNotSupportedDomainWideDelegation,
65 },
66 }
67
68 for _, tt := range tests {
69 name := tt.name
70 t.Run(name, func(t *testing.T) {
71 saTok := "sa-token"
72 client := &http.Client{
73 Transport: RoundTripFn(func(req *http.Request) *http.Response {
74 if strings.Contains(req.URL.Path, "generateAccessToken") {
75 resp := generateAccessTokenResp{
76 AccessToken: saTok,
77 ExpireTime: time.Now().Format(time.RFC3339),
78 }
79 b, err := json.Marshal(&resp)
80 if err != nil {
81 t.Fatalf("unable to marshal response: %v", err)
82 }
83 return &http.Response{
84 StatusCode: 200,
85 Body: io.NopCloser(bytes.NewReader(b)),
86 Header: http.Header{},
87 }
88 }
89 return nil
90 }),
91 }
92 opts := []option.ClientOption{
93 option.WithHTTPClient(client),
94 }
95 if tt.opts != nil {
96 opts = append(opts, tt.opts)
97 }
98 ts, err := CredentialsTokenSource(ctx, tt.config, opts...)
99
100 if err != nil {
101 if err != tt.wantErr {
102 t.Fatalf("%s: err: %v", tt.name, err)
103 }
104 } else {
105 tok, err := ts.Token()
106 if err != nil {
107 t.Fatal(err)
108 }
109 if tok.AccessToken != saTok {
110 t.Fatalf("got %q, want %q", tok.AccessToken, saTok)
111 }
112 }
113 })
114 }
115 }
116
// RoundTripFn adapts a plain function into an http.RoundTripper so tests can
// serve canned responses without touching the network.
type RoundTripFn func(req *http.Request) *http.Response

// RoundTrip implements http.RoundTripper by delegating to f.
//
// When f returns nil, RoundTrip reports an error that names the request,
// rather than returning a nil response with a nil error, which the
// http.RoundTripper contract does not permit.
func (f RoundTripFn) RoundTrip(req *http.Request) (*http.Response, error) {
	resp := f(req)
	if resp == nil {
		return nil, fmt.Errorf("RoundTripFn: no response stubbed for %s %s", req.Method, req.URL)
	}
	return resp, nil
}
120
View as plain text