1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package impersonate
16
17 import (
18 "bytes"
19 "context"
20 "encoding/json"
21 "io"
22 "net/http"
23 "strings"
24 "testing"
25 "time"
26
27 "cloud.google.com/go/auth/internal"
28 "cloud.google.com/go/auth/internal/jwt"
29 )
30
31 func TestNewCredentials_user(t *testing.T) {
32 ctx := context.Background()
33 tests := []struct {
34 name string
35 targetPrincipal string
36 scopes []string
37 lifetime time.Duration
38 subject string
39 wantErr bool
40 universeDomain string
41 }{
42 {
43 name: "missing targetPrincipal",
44 wantErr: true,
45 },
46 {
47 name: "missing scopes",
48 targetPrincipal: "foo@project-id.iam.gserviceaccount.com",
49 wantErr: true,
50 },
51 {
52 name: "lifetime over max",
53 targetPrincipal: "foo@project-id.iam.gserviceaccount.com",
54 scopes: []string{"scope"},
55 lifetime: 13 * time.Hour,
56 wantErr: true,
57 },
58 {
59 name: "works",
60 targetPrincipal: "foo@project-id.iam.gserviceaccount.com",
61 scopes: []string{"scope"},
62 subject: "admin@example.com",
63 wantErr: false,
64 },
65 {
66 name: "universeDomain",
67 targetPrincipal: "foo@project-id.iam.gserviceaccount.com",
68 scopes: []string{"scope"},
69 subject: "admin@example.com",
70 wantErr: true,
71
72
73 universeDomain: "example.com",
74 },
75 }
76
77 for _, tt := range tests {
78 userTok := "user-token"
79 name := tt.name
80 t.Run(name, func(t *testing.T) {
81 client := &http.Client{
82 Transport: RoundTripFn(func(req *http.Request) *http.Response {
83 defer req.Body.Close()
84 if strings.Contains(req.URL.Path, "signJwt") {
85 b, err := io.ReadAll(req.Body)
86 if err != nil {
87 t.Error(err)
88 }
89 var r signJWTRequest
90 if err := json.Unmarshal(b, &r); err != nil {
91 t.Error(err)
92 }
93 jwtPayload := map[string]interface{}{}
94 if err := json.Unmarshal([]byte(r.Payload), &jwtPayload); err != nil {
95 t.Error(err)
96 }
97 if got, want := jwtPayload["iss"].(string), tt.targetPrincipal; got != want {
98 t.Errorf("got %q, want %q", got, want)
99 }
100 if got, want := jwtPayload["sub"].(string), tt.subject; got != want {
101 t.Errorf("got %q, want %q", got, want)
102 }
103 if got, want := jwtPayload["scope"].(string), strings.Join(tt.scopes, ","); got != want {
104 t.Errorf("got %q, want %q", got, want)
105 }
106
107 resp := signJWTResponse{
108 KeyID: "123",
109 SignedJWT: jwt.HeaderType,
110 }
111 b, err = json.Marshal(&resp)
112 if err != nil {
113 t.Fatalf("unable to marshal response: %v", err)
114 }
115 return &http.Response{
116 StatusCode: 200,
117 Body: io.NopCloser(bytes.NewReader(b)),
118 Header: make(http.Header),
119 }
120 }
121 if strings.Contains(req.URL.Path, "/token") {
122 resp := exchangeTokenResponse{
123 AccessToken: userTok,
124 TokenType: internal.TokenTypeBearer,
125 ExpiresIn: int64(time.Hour.Seconds()),
126 }
127 b, err := json.Marshal(&resp)
128 if err != nil {
129 t.Fatalf("unable to marshal response: %v", err)
130 }
131 return &http.Response{
132 StatusCode: 200,
133 Body: io.NopCloser(bytes.NewReader(b)),
134 Header: make(http.Header),
135 }
136 }
137 return nil
138 }),
139 }
140 ts, err := NewCredentials(&CredentialsOptions{
141 TargetPrincipal: tt.targetPrincipal,
142 Scopes: tt.scopes,
143 Lifetime: tt.lifetime,
144 Subject: tt.subject,
145 Client: client,
146 UniverseDomain: tt.universeDomain,
147 })
148 if tt.wantErr && err != nil {
149 return
150 }
151 if err != nil {
152 t.Fatal(err)
153 }
154 tok, err := ts.Token(ctx)
155 if err != nil {
156 t.Fatal(err)
157 }
158 if tok.Value != userTok {
159 t.Fatalf("got %q, want %q", tok.Value, userTok)
160 }
161 })
162 }
163 }
164
View as plain text