// Copyright 2023 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package httptransport import ( "context" "net/http" "net/http/httptest" "strings" "testing" "cloud.google.com/go/auth" "cloud.google.com/go/auth/credentials" "cloud.google.com/go/auth/internal" "github.com/google/go-cmp/cmp" ) func TestAddAuthorizationMiddleware(t *testing.T) { creds := auth.NewCredentials(&auth.CredentialsOptions{ TokenProvider: staticTP("fakeToken"), }) tests := []struct { name string client *http.Client creds *auth.Credentials wantErr bool want string }{ { name: "missing both required fields", wantErr: true, }, { name: "missing client field", creds: creds, wantErr: true, }, { name: "missing creds field", client: internal.CloneDefaultClient(), wantErr: true, }, { name: "works", client: internal.CloneDefaultClient(), creds: creds, want: "fakeToken", }, { name: "works, no transport", client: &http.Client{}, creds: creds, want: "fakeToken", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := AddAuthorizationMiddleware(tt.client, tt.creds) if tt.wantErr && err == nil { t.Fatalf("AddAuthorizationMiddleware() = nil, want error") } if tt.wantErr { return } ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { got := r.Header.Get("Authorization") if !strings.Contains(got, tt.want) { t.Errorf("got %q, want contain %q", got, tt.want) } })) defer ts.Close() tt.client.Get(ts.URL) }) } } func TestNewClient_FailsValidation(t *testing.T) { tests := []struct { name string opts *Options }{ { name: "missing options", }, { name: "has creds with disable options, tp", opts: &Options{ DisableAuthentication: true, Credentials: auth.NewCredentials(&auth.CredentialsOptions{ TokenProvider: staticTP("fakeToken"), }), }, }, { name: "has creds with disable options, cred file", opts: &Options{ DisableAuthentication: true, DetectOpts: &credentials.DetectOptions{ CredentialsFile: "abc.123", }, }, }, { name: "has creds with disable options, cred json", opts: &Options{ DisableAuthentication: true, DetectOpts: &credentials.DetectOptions{ CredentialsJSON: []byte(`{"foo":"bar"}`), }, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { _, err := NewClient(tt.opts) if err == nil { t.Fatal("NewClient() = _, nil, want error") } }) } } func TestDial_SkipValidation(t *testing.T) { opts := &Options{ DisableAuthentication: true, Credentials: auth.NewCredentials(&auth.CredentialsOptions{ TokenProvider: staticTP("fakeToken"), }), } t.Run("invalid opts", func(t *testing.T) { if err := opts.validate(); err == nil { t.Fatalf("opts.validate() = nil, want error") } }) t.Run("skip invalid opts", func(t *testing.T) { opts.InternalOptions = &InternalOptions{SkipValidation: true} if err := opts.validate(); err != nil { t.Fatalf("opts.validate() = %v, want nil", err) } }) } func TestOptions_ResolveDetectOptions(t *testing.T) { tests := []struct { name string in *Options want *credentials.DetectOptions }{ { name: "base", in: &Options{ DetectOpts: &credentials.DetectOptions{ Scopes: []string{"scope"}, CredentialsFile: "/path/to/a/file", }, }, want: &credentials.DetectOptions{ Scopes: []string{"scope"}, CredentialsFile: "/path/to/a/file", }, }, { name: "self-signed, with scope", in: &Options{ InternalOptions: &InternalOptions{ EnableJWTWithScope: true, }, DetectOpts: &credentials.DetectOptions{ Scopes: []string{"scope"}, CredentialsFile: "/path/to/a/file", }, }, want: &credentials.DetectOptions{ Scopes: []string{"scope"}, CredentialsFile: "/path/to/a/file", UseSelfSignedJWT: true, }, }, { name: "self-signed, with aud", in: &Options{ DetectOpts: &credentials.DetectOptions{ Audience: "aud", CredentialsFile: "/path/to/a/file", }, }, want: &credentials.DetectOptions{ Audience: "aud", CredentialsFile: "/path/to/a/file", UseSelfSignedJWT: true, }, }, { name: "use default scopes", in: &Options{ InternalOptions: &InternalOptions{ DefaultScopes: []string{"default"}, DefaultAudience: "default", }, DetectOpts: &credentials.DetectOptions{ CredentialsFile: "/path/to/a/file", }, }, want: &credentials.DetectOptions{ Scopes: []string{"default"}, CredentialsFile: "/path/to/a/file", }, }, { name: "don't use default scopes, scope provided", in: &Options{ InternalOptions: &InternalOptions{ DefaultScopes: []string{"default"}, DefaultAudience: "default", }, DetectOpts: &credentials.DetectOptions{ Scopes: []string{"non-default"}, CredentialsFile: "/path/to/a/file", }, }, want: &credentials.DetectOptions{ Scopes: []string{"non-default"}, CredentialsFile: "/path/to/a/file", }, }, { name: "don't use default scopes, aud provided", in: &Options{ InternalOptions: &InternalOptions{ DefaultScopes: []string{"default"}, DefaultAudience: "default", }, DetectOpts: &credentials.DetectOptions{ Audience: "non-default", CredentialsFile: "/path/to/a/file", }, }, want: &credentials.DetectOptions{ Audience: "non-default", CredentialsFile: "/path/to/a/file", UseSelfSignedJWT: true, }, }, { name: "use default aud", in: &Options{ InternalOptions: &InternalOptions{ DefaultAudience: "default", }, DetectOpts: &credentials.DetectOptions{ CredentialsFile: "/path/to/a/file", }, }, want: &credentials.DetectOptions{ Audience: "default", CredentialsFile: "/path/to/a/file", }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := tt.in.resolveDetectOptions() if diff := cmp.Diff(tt.want, got); diff != "" { t.Errorf("mismatch (-want +got):\n%s", diff) } }) } } func TestNewClient_DetectedServiceAccount(t *testing.T) { testQuota := "testquota" wantHeader := "bar" t.Setenv(internal.QuotaProjectEnvVar, testQuota) ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if got := r.Header.Get("Authorization"); got == "" { t.Errorf(`got "", want an auth token`) } if got := r.Header.Get("Foo"); got != wantHeader { t.Errorf("got %q, want %q", got, wantHeader) } if got := r.Header.Get(quotaProjectHeaderKey); got != testQuota { t.Errorf("got %q, want %q", got, testQuota) } })) defer ts.Close() client, err := NewClient(&Options{ Headers: http.Header{"Foo": []string{wantHeader}}, InternalOptions: &InternalOptions{ DefaultEndpointTemplate: ts.URL, }, DetectOpts: &credentials.DetectOptions{ Audience: ts.URL, CredentialsFile: "../internal/testdata/sa.json", UseSelfSignedJWT: true, }, }) if err != nil { t.Fatalf("NewClient() = %v", err) } req, err := http.NewRequest(http.MethodGet, ts.URL, nil) if err != nil { t.Fatal(err) } if _, err := client.Do(req); err != nil { t.Fatalf("client.Get() = %v", err) } } func TestNewClient_APIKey(t *testing.T) { testQuota := "testquota" apiKey := "thereisnospoon" wantHeader := "bar" t.Setenv(internal.QuotaProjectEnvVar, testQuota) ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { got := r.URL.Query().Get("key") if got != apiKey { t.Errorf("got %q, want %q", got, apiKey) } if got := r.Header.Get("Foo"); got != wantHeader { t.Errorf("got %q, want %q", got, wantHeader) } if got := r.Header.Get(quotaProjectHeaderKey); got != testQuota { t.Errorf("got %q, want %q", got, testQuota) } })) defer ts.Close() client, err := NewClient(&Options{ APIKey: apiKey, Headers: http.Header{"Foo": []string{wantHeader}}, }) if err != nil { t.Fatalf("NewClient() = %v", err) } if _, err := client.Get(ts.URL); err != nil { t.Fatalf("client.Get() = %v", err) } } func TestNewClient_BaseRoundTripper(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { got := r.Header.Get("Foo") if want := "foo"; got != want { t.Errorf("got %q, want %q", got, want) } got = r.Header.Get("Bar") if want := "bar"; got != want { t.Errorf("got %q, want %q", got, want) } })) defer ts.Close() client, err := NewClient(&Options{ BaseRoundTripper: &rt{key: "Bar", value: "bar"}, Headers: http.Header{"Foo": []string{"foo"}}, APIKey: "key", }) if err != nil { t.Fatalf("NewClient() = %v", err) } if _, err := client.Get(ts.URL); err != nil { t.Fatalf("client.Get() = %v", err) } } type staticTP string func (tp staticTP) Token(context.Context) (*auth.Token, error) { return &auth.Token{ Value: string(tp), }, nil } type rt struct { key string value string } func (r *rt) RoundTrip(req *http.Request) (*http.Response, error) { req2 := req.Clone(req.Context()) req2.Header.Add(r.key, r.value) return http.DefaultTransport.RoundTrip(req2) }