1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package httptransport
16
17 import (
18 "context"
19 "net/http"
20 "net/http/httptest"
21 "strings"
22 "testing"
23
24 "cloud.google.com/go/auth"
25 "cloud.google.com/go/auth/credentials"
26 "cloud.google.com/go/auth/internal"
27 "github.com/google/go-cmp/cmp"
28 )
29
30 func TestAddAuthorizationMiddleware(t *testing.T) {
31 creds := auth.NewCredentials(&auth.CredentialsOptions{
32 TokenProvider: staticTP("fakeToken"),
33 })
34 tests := []struct {
35 name string
36 client *http.Client
37 creds *auth.Credentials
38 wantErr bool
39 want string
40 }{
41 {
42 name: "missing both required fields",
43 wantErr: true,
44 },
45 {
46 name: "missing client field",
47 creds: creds,
48 wantErr: true,
49 },
50 {
51 name: "missing creds field",
52 client: internal.CloneDefaultClient(),
53 wantErr: true,
54 },
55 {
56 name: "works",
57 client: internal.CloneDefaultClient(),
58 creds: creds,
59 want: "fakeToken",
60 },
61 {
62 name: "works, no transport",
63 client: &http.Client{},
64 creds: creds,
65 want: "fakeToken",
66 },
67 }
68 for _, tt := range tests {
69 t.Run(tt.name, func(t *testing.T) {
70 err := AddAuthorizationMiddleware(tt.client, tt.creds)
71 if tt.wantErr && err == nil {
72 t.Fatalf("AddAuthorizationMiddleware() = nil, want error")
73 }
74 if tt.wantErr {
75 return
76 }
77 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
78 got := r.Header.Get("Authorization")
79 if !strings.Contains(got, tt.want) {
80 t.Errorf("got %q, want contain %q", got, tt.want)
81 }
82
83 }))
84 defer ts.Close()
85 tt.client.Get(ts.URL)
86 })
87 }
88 }
89
90 func TestNewClient_FailsValidation(t *testing.T) {
91 tests := []struct {
92 name string
93 opts *Options
94 }{
95 {
96 name: "missing options",
97 },
98 {
99 name: "has creds with disable options, tp",
100 opts: &Options{
101 DisableAuthentication: true,
102 Credentials: auth.NewCredentials(&auth.CredentialsOptions{
103 TokenProvider: staticTP("fakeToken"),
104 }),
105 },
106 },
107 {
108 name: "has creds with disable options, cred file",
109 opts: &Options{
110 DisableAuthentication: true,
111 DetectOpts: &credentials.DetectOptions{
112 CredentialsFile: "abc.123",
113 },
114 },
115 },
116 {
117 name: "has creds with disable options, cred json",
118 opts: &Options{
119 DisableAuthentication: true,
120 DetectOpts: &credentials.DetectOptions{
121 CredentialsJSON: []byte(`{"foo":"bar"}`),
122 },
123 },
124 },
125 }
126 for _, tt := range tests {
127 t.Run(tt.name, func(t *testing.T) {
128 _, err := NewClient(tt.opts)
129 if err == nil {
130 t.Fatal("NewClient() = _, nil, want error")
131 }
132 })
133 }
134 }
135
136 func TestDial_SkipValidation(t *testing.T) {
137 opts := &Options{
138 DisableAuthentication: true,
139 Credentials: auth.NewCredentials(&auth.CredentialsOptions{
140 TokenProvider: staticTP("fakeToken"),
141 }),
142 }
143 t.Run("invalid opts", func(t *testing.T) {
144 if err := opts.validate(); err == nil {
145 t.Fatalf("opts.validate() = nil, want error")
146 }
147 })
148
149 t.Run("skip invalid opts", func(t *testing.T) {
150 opts.InternalOptions = &InternalOptions{SkipValidation: true}
151 if err := opts.validate(); err != nil {
152 t.Fatalf("opts.validate() = %v, want nil", err)
153 }
154 })
155 }
156
157 func TestOptions_ResolveDetectOptions(t *testing.T) {
158 tests := []struct {
159 name string
160 in *Options
161 want *credentials.DetectOptions
162 }{
163 {
164 name: "base",
165 in: &Options{
166 DetectOpts: &credentials.DetectOptions{
167 Scopes: []string{"scope"},
168 CredentialsFile: "/path/to/a/file",
169 },
170 },
171 want: &credentials.DetectOptions{
172 Scopes: []string{"scope"},
173 CredentialsFile: "/path/to/a/file",
174 },
175 },
176 {
177 name: "self-signed, with scope",
178 in: &Options{
179 InternalOptions: &InternalOptions{
180 EnableJWTWithScope: true,
181 },
182 DetectOpts: &credentials.DetectOptions{
183 Scopes: []string{"scope"},
184 CredentialsFile: "/path/to/a/file",
185 },
186 },
187 want: &credentials.DetectOptions{
188 Scopes: []string{"scope"},
189 CredentialsFile: "/path/to/a/file",
190 UseSelfSignedJWT: true,
191 },
192 },
193 {
194 name: "self-signed, with aud",
195 in: &Options{
196 DetectOpts: &credentials.DetectOptions{
197 Audience: "aud",
198 CredentialsFile: "/path/to/a/file",
199 },
200 },
201 want: &credentials.DetectOptions{
202 Audience: "aud",
203 CredentialsFile: "/path/to/a/file",
204 UseSelfSignedJWT: true,
205 },
206 },
207 {
208 name: "use default scopes",
209 in: &Options{
210 InternalOptions: &InternalOptions{
211 DefaultScopes: []string{"default"},
212 DefaultAudience: "default",
213 },
214 DetectOpts: &credentials.DetectOptions{
215 CredentialsFile: "/path/to/a/file",
216 },
217 },
218 want: &credentials.DetectOptions{
219 Scopes: []string{"default"},
220 CredentialsFile: "/path/to/a/file",
221 },
222 },
223 {
224 name: "don't use default scopes, scope provided",
225 in: &Options{
226 InternalOptions: &InternalOptions{
227 DefaultScopes: []string{"default"},
228 DefaultAudience: "default",
229 },
230 DetectOpts: &credentials.DetectOptions{
231 Scopes: []string{"non-default"},
232 CredentialsFile: "/path/to/a/file",
233 },
234 },
235 want: &credentials.DetectOptions{
236 Scopes: []string{"non-default"},
237 CredentialsFile: "/path/to/a/file",
238 },
239 },
240 {
241 name: "don't use default scopes, aud provided",
242 in: &Options{
243 InternalOptions: &InternalOptions{
244 DefaultScopes: []string{"default"},
245 DefaultAudience: "default",
246 },
247 DetectOpts: &credentials.DetectOptions{
248 Audience: "non-default",
249 CredentialsFile: "/path/to/a/file",
250 },
251 },
252 want: &credentials.DetectOptions{
253 Audience: "non-default",
254 CredentialsFile: "/path/to/a/file",
255 UseSelfSignedJWT: true,
256 },
257 },
258 {
259 name: "use default aud",
260 in: &Options{
261 InternalOptions: &InternalOptions{
262 DefaultAudience: "default",
263 },
264 DetectOpts: &credentials.DetectOptions{
265 CredentialsFile: "/path/to/a/file",
266 },
267 },
268 want: &credentials.DetectOptions{
269 Audience: "default",
270 CredentialsFile: "/path/to/a/file",
271 },
272 },
273 }
274 for _, tt := range tests {
275 t.Run(tt.name, func(t *testing.T) {
276 got := tt.in.resolveDetectOptions()
277 if diff := cmp.Diff(tt.want, got); diff != "" {
278 t.Errorf("mismatch (-want +got):\n%s", diff)
279 }
280 })
281 }
282 }
283
284 func TestNewClient_DetectedServiceAccount(t *testing.T) {
285 testQuota := "testquota"
286 wantHeader := "bar"
287 t.Setenv(internal.QuotaProjectEnvVar, testQuota)
288 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
289 if got := r.Header.Get("Authorization"); got == "" {
290 t.Errorf(`got "", want an auth token`)
291 }
292 if got := r.Header.Get("Foo"); got != wantHeader {
293 t.Errorf("got %q, want %q", got, wantHeader)
294 }
295 if got := r.Header.Get(quotaProjectHeaderKey); got != testQuota {
296 t.Errorf("got %q, want %q", got, testQuota)
297 }
298 }))
299 defer ts.Close()
300 client, err := NewClient(&Options{
301 Headers: http.Header{"Foo": []string{wantHeader}},
302 InternalOptions: &InternalOptions{
303 DefaultEndpointTemplate: ts.URL,
304 },
305 DetectOpts: &credentials.DetectOptions{
306 Audience: ts.URL,
307 CredentialsFile: "../internal/testdata/sa.json",
308 UseSelfSignedJWT: true,
309 },
310 })
311 if err != nil {
312 t.Fatalf("NewClient() = %v", err)
313 }
314 req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
315 if err != nil {
316 t.Fatal(err)
317 }
318 if _, err := client.Do(req); err != nil {
319 t.Fatalf("client.Get() = %v", err)
320 }
321 }
322
323 func TestNewClient_APIKey(t *testing.T) {
324 testQuota := "testquota"
325 apiKey := "thereisnospoon"
326 wantHeader := "bar"
327 t.Setenv(internal.QuotaProjectEnvVar, testQuota)
328 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
329 got := r.URL.Query().Get("key")
330 if got != apiKey {
331 t.Errorf("got %q, want %q", got, apiKey)
332 }
333 if got := r.Header.Get("Foo"); got != wantHeader {
334 t.Errorf("got %q, want %q", got, wantHeader)
335 }
336 if got := r.Header.Get(quotaProjectHeaderKey); got != testQuota {
337 t.Errorf("got %q, want %q", got, testQuota)
338 }
339 }))
340 defer ts.Close()
341 client, err := NewClient(&Options{
342 APIKey: apiKey,
343 Headers: http.Header{"Foo": []string{wantHeader}},
344 })
345 if err != nil {
346 t.Fatalf("NewClient() = %v", err)
347 }
348 if _, err := client.Get(ts.URL); err != nil {
349 t.Fatalf("client.Get() = %v", err)
350 }
351 }
352
353 func TestNewClient_BaseRoundTripper(t *testing.T) {
354 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
355 got := r.Header.Get("Foo")
356 if want := "foo"; got != want {
357 t.Errorf("got %q, want %q", got, want)
358 }
359 got = r.Header.Get("Bar")
360 if want := "bar"; got != want {
361 t.Errorf("got %q, want %q", got, want)
362 }
363 }))
364 defer ts.Close()
365 client, err := NewClient(&Options{
366 BaseRoundTripper: &rt{key: "Bar", value: "bar"},
367 Headers: http.Header{"Foo": []string{"foo"}},
368 APIKey: "key",
369 })
370 if err != nil {
371 t.Fatalf("NewClient() = %v", err)
372 }
373 if _, err := client.Get(ts.URL); err != nil {
374 t.Fatalf("client.Get() = %v", err)
375 }
376 }
377
378 type staticTP string
379
380 func (tp staticTP) Token(context.Context) (*auth.Token, error) {
381 return &auth.Token{
382 Value: string(tp),
383 }, nil
384 }
385
// rt is a test http.RoundTripper. It adds one header (key: value) to each
// outgoing request and then delegates to http.DefaultTransport.
type rt struct {
	key   string
	value string
}

// RoundTrip sends a clone of req with the configured header appended, so the
// caller's request is never modified.
func (r *rt) RoundTrip(req *http.Request) (*http.Response, error) {
	ctx := req.Context()
	cloned := req.Clone(ctx)
	cloned.Header.Add(r.key, r.value)
	return http.DefaultTransport.RoundTrip(cloned)
}
396
View as plain text