1
2
3
4
5 package impersonate
6
7 import (
8 "bytes"
9 "context"
10 "encoding/json"
11 "io"
12 "net/http"
13 "strings"
14 "testing"
15 "time"
16
17 "google.golang.org/api/option"
18 )
19
20 func TestTokenSource_user(t *testing.T) {
21 ctx := context.Background()
22 tests := []struct {
23 name string
24 targetPrincipal string
25 scopes []string
26 lifetime time.Duration
27 subject string
28 wantErr bool
29 universeDomain string
30 }{
31 {
32 name: "missing targetPrincipal",
33 wantErr: true,
34 },
35 {
36 name: "missing scopes",
37 targetPrincipal: "foo@project-id.iam.gserviceaccount.com",
38 wantErr: true,
39 },
40 {
41 name: "lifetime over max",
42 targetPrincipal: "foo@project-id.iam.gserviceaccount.com",
43 scopes: []string{"scope"},
44 lifetime: 13 * time.Hour,
45 wantErr: true,
46 },
47 {
48 name: "works",
49 targetPrincipal: "foo@project-id.iam.gserviceaccount.com",
50 scopes: []string{"scope"},
51 subject: "admin@example.com",
52 wantErr: false,
53 },
54 {
55 name: "universeDomain",
56 targetPrincipal: "foo@project-id.iam.gserviceaccount.com",
57 scopes: []string{"scope"},
58 subject: "admin@example.com",
59 wantErr: true,
60
61
62 universeDomain: "example.com",
63 },
64 }
65
66 for _, tt := range tests {
67 userTok := "user-token"
68 name := tt.name
69 t.Run(name, func(t *testing.T) {
70 client := &http.Client{
71 Transport: RoundTripFn(func(req *http.Request) *http.Response {
72 if strings.Contains(req.URL.Path, "signJwt") {
73 resp := signJWTResponse{
74 KeyID: "123",
75 SignedJWT: "jwt",
76 }
77 b, err := json.Marshal(&resp)
78 if err != nil {
79 t.Fatalf("unable to marshal response: %v", err)
80 }
81 return &http.Response{
82 StatusCode: 200,
83 Body: io.NopCloser(bytes.NewReader(b)),
84 Header: make(http.Header),
85 }
86 }
87 if strings.Contains(req.URL.Path, "/token") {
88 resp := exchangeTokenResponse{
89 AccessToken: userTok,
90 TokenType: "Bearer",
91 ExpiresIn: int64(time.Hour.Seconds()),
92 }
93 b, err := json.Marshal(&resp)
94 if err != nil {
95 t.Fatalf("unable to marshal response: %v", err)
96 }
97 return &http.Response{
98 StatusCode: 200,
99 Body: io.NopCloser(bytes.NewReader(b)),
100 Header: make(http.Header),
101 }
102 }
103 return nil
104 }),
105 }
106 ts, err := CredentialsTokenSource(ctx,
107 CredentialsConfig{
108 TargetPrincipal: tt.targetPrincipal,
109 Scopes: tt.scopes,
110 Lifetime: tt.lifetime,
111 Subject: tt.subject,
112 },
113 option.WithHTTPClient(client),
114 option.WithUniverseDomain(tt.universeDomain))
115 if tt.wantErr && err != nil {
116 return
117 }
118 if err != nil {
119 t.Fatal(err)
120 }
121 tok, err := ts.Token()
122 if err != nil {
123 t.Fatal(err)
124 }
125 if tok.AccessToken != userTok {
126 t.Fatalf("got %q, want %q", tok.AccessToken, userTok)
127 }
128 })
129 }
130 }
131
View as plain text