...

Source file src/cloud.google.com/go/auth/httptransport/httptransport_test.go

Documentation: cloud.google.com/go/auth/httptransport

     1  // Copyright 2023 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package httptransport
    16  
    17  import (
    18  	"context"
    19  	"net/http"
    20  	"net/http/httptest"
    21  	"strings"
    22  	"testing"
    23  
    24  	"cloud.google.com/go/auth"
    25  	"cloud.google.com/go/auth/credentials"
    26  	"cloud.google.com/go/auth/internal"
    27  	"github.com/google/go-cmp/cmp"
    28  )
    29  
    30  func TestAddAuthorizationMiddleware(t *testing.T) {
    31  	creds := auth.NewCredentials(&auth.CredentialsOptions{
    32  		TokenProvider: staticTP("fakeToken"),
    33  	})
    34  	tests := []struct {
    35  		name    string
    36  		client  *http.Client
    37  		creds   *auth.Credentials
    38  		wantErr bool
    39  		want    string
    40  	}{
    41  		{
    42  			name:    "missing both required fields",
    43  			wantErr: true,
    44  		},
    45  		{
    46  			name:    "missing client field",
    47  			creds:   creds,
    48  			wantErr: true,
    49  		},
    50  		{
    51  			name:    "missing creds field",
    52  			client:  internal.CloneDefaultClient(),
    53  			wantErr: true,
    54  		},
    55  		{
    56  			name:   "works",
    57  			client: internal.CloneDefaultClient(),
    58  			creds:  creds,
    59  			want:   "fakeToken",
    60  		},
    61  		{
    62  			name:   "works, no transport",
    63  			client: &http.Client{},
    64  			creds:  creds,
    65  			want:   "fakeToken",
    66  		},
    67  	}
    68  	for _, tt := range tests {
    69  		t.Run(tt.name, func(t *testing.T) {
    70  			err := AddAuthorizationMiddleware(tt.client, tt.creds)
    71  			if tt.wantErr && err == nil {
    72  				t.Fatalf("AddAuthorizationMiddleware() = nil, want error")
    73  			}
    74  			if tt.wantErr {
    75  				return
    76  			}
    77  			ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    78  				got := r.Header.Get("Authorization")
    79  				if !strings.Contains(got, tt.want) {
    80  					t.Errorf("got %q, want contain %q", got, tt.want)
    81  				}
    82  
    83  			}))
    84  			defer ts.Close()
    85  			tt.client.Get(ts.URL)
    86  		})
    87  	}
    88  }
    89  
    90  func TestNewClient_FailsValidation(t *testing.T) {
    91  	tests := []struct {
    92  		name string
    93  		opts *Options
    94  	}{
    95  		{
    96  			name: "missing options",
    97  		},
    98  		{
    99  			name: "has creds with disable options, tp",
   100  			opts: &Options{
   101  				DisableAuthentication: true,
   102  				Credentials: auth.NewCredentials(&auth.CredentialsOptions{
   103  					TokenProvider: staticTP("fakeToken"),
   104  				}),
   105  			},
   106  		},
   107  		{
   108  			name: "has creds with disable options, cred file",
   109  			opts: &Options{
   110  				DisableAuthentication: true,
   111  				DetectOpts: &credentials.DetectOptions{
   112  					CredentialsFile: "abc.123",
   113  				},
   114  			},
   115  		},
   116  		{
   117  			name: "has creds with disable options, cred json",
   118  			opts: &Options{
   119  				DisableAuthentication: true,
   120  				DetectOpts: &credentials.DetectOptions{
   121  					CredentialsJSON: []byte(`{"foo":"bar"}`),
   122  				},
   123  			},
   124  		},
   125  	}
   126  	for _, tt := range tests {
   127  		t.Run(tt.name, func(t *testing.T) {
   128  			_, err := NewClient(tt.opts)
   129  			if err == nil {
   130  				t.Fatal("NewClient() = _, nil, want error")
   131  			}
   132  		})
   133  	}
   134  }
   135  
   136  func TestDial_SkipValidation(t *testing.T) {
   137  	opts := &Options{
   138  		DisableAuthentication: true,
   139  		Credentials: auth.NewCredentials(&auth.CredentialsOptions{
   140  			TokenProvider: staticTP("fakeToken"),
   141  		}),
   142  	}
   143  	t.Run("invalid opts", func(t *testing.T) {
   144  		if err := opts.validate(); err == nil {
   145  			t.Fatalf("opts.validate() = nil, want error")
   146  		}
   147  	})
   148  
   149  	t.Run("skip invalid opts", func(t *testing.T) {
   150  		opts.InternalOptions = &InternalOptions{SkipValidation: true}
   151  		if err := opts.validate(); err != nil {
   152  			t.Fatalf("opts.validate() = %v, want nil", err)
   153  		}
   154  	})
   155  }
   156  
   157  func TestOptions_ResolveDetectOptions(t *testing.T) {
   158  	tests := []struct {
   159  		name string
   160  		in   *Options
   161  		want *credentials.DetectOptions
   162  	}{
   163  		{
   164  			name: "base",
   165  			in: &Options{
   166  				DetectOpts: &credentials.DetectOptions{
   167  					Scopes:          []string{"scope"},
   168  					CredentialsFile: "/path/to/a/file",
   169  				},
   170  			},
   171  			want: &credentials.DetectOptions{
   172  				Scopes:          []string{"scope"},
   173  				CredentialsFile: "/path/to/a/file",
   174  			},
   175  		},
   176  		{
   177  			name: "self-signed, with scope",
   178  			in: &Options{
   179  				InternalOptions: &InternalOptions{
   180  					EnableJWTWithScope: true,
   181  				},
   182  				DetectOpts: &credentials.DetectOptions{
   183  					Scopes:          []string{"scope"},
   184  					CredentialsFile: "/path/to/a/file",
   185  				},
   186  			},
   187  			want: &credentials.DetectOptions{
   188  				Scopes:           []string{"scope"},
   189  				CredentialsFile:  "/path/to/a/file",
   190  				UseSelfSignedJWT: true,
   191  			},
   192  		},
   193  		{
   194  			name: "self-signed, with aud",
   195  			in: &Options{
   196  				DetectOpts: &credentials.DetectOptions{
   197  					Audience:        "aud",
   198  					CredentialsFile: "/path/to/a/file",
   199  				},
   200  			},
   201  			want: &credentials.DetectOptions{
   202  				Audience:         "aud",
   203  				CredentialsFile:  "/path/to/a/file",
   204  				UseSelfSignedJWT: true,
   205  			},
   206  		},
   207  		{
   208  			name: "use default scopes",
   209  			in: &Options{
   210  				InternalOptions: &InternalOptions{
   211  					DefaultScopes:   []string{"default"},
   212  					DefaultAudience: "default",
   213  				},
   214  				DetectOpts: &credentials.DetectOptions{
   215  					CredentialsFile: "/path/to/a/file",
   216  				},
   217  			},
   218  			want: &credentials.DetectOptions{
   219  				Scopes:          []string{"default"},
   220  				CredentialsFile: "/path/to/a/file",
   221  			},
   222  		},
   223  		{
   224  			name: "don't use default scopes, scope provided",
   225  			in: &Options{
   226  				InternalOptions: &InternalOptions{
   227  					DefaultScopes:   []string{"default"},
   228  					DefaultAudience: "default",
   229  				},
   230  				DetectOpts: &credentials.DetectOptions{
   231  					Scopes:          []string{"non-default"},
   232  					CredentialsFile: "/path/to/a/file",
   233  				},
   234  			},
   235  			want: &credentials.DetectOptions{
   236  				Scopes:          []string{"non-default"},
   237  				CredentialsFile: "/path/to/a/file",
   238  			},
   239  		},
   240  		{
   241  			name: "don't use default scopes, aud provided",
   242  			in: &Options{
   243  				InternalOptions: &InternalOptions{
   244  					DefaultScopes:   []string{"default"},
   245  					DefaultAudience: "default",
   246  				},
   247  				DetectOpts: &credentials.DetectOptions{
   248  					Audience:        "non-default",
   249  					CredentialsFile: "/path/to/a/file",
   250  				},
   251  			},
   252  			want: &credentials.DetectOptions{
   253  				Audience:         "non-default",
   254  				CredentialsFile:  "/path/to/a/file",
   255  				UseSelfSignedJWT: true,
   256  			},
   257  		},
   258  		{
   259  			name: "use default aud",
   260  			in: &Options{
   261  				InternalOptions: &InternalOptions{
   262  					DefaultAudience: "default",
   263  				},
   264  				DetectOpts: &credentials.DetectOptions{
   265  					CredentialsFile: "/path/to/a/file",
   266  				},
   267  			},
   268  			want: &credentials.DetectOptions{
   269  				Audience:        "default",
   270  				CredentialsFile: "/path/to/a/file",
   271  			},
   272  		},
   273  	}
   274  	for _, tt := range tests {
   275  		t.Run(tt.name, func(t *testing.T) {
   276  			got := tt.in.resolveDetectOptions()
   277  			if diff := cmp.Diff(tt.want, got); diff != "" {
   278  				t.Errorf("mismatch (-want +got):\n%s", diff)
   279  			}
   280  		})
   281  	}
   282  }
   283  
   284  func TestNewClient_DetectedServiceAccount(t *testing.T) {
   285  	testQuota := "testquota"
   286  	wantHeader := "bar"
   287  	t.Setenv(internal.QuotaProjectEnvVar, testQuota)
   288  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   289  		if got := r.Header.Get("Authorization"); got == "" {
   290  			t.Errorf(`got "", want an auth token`)
   291  		}
   292  		if got := r.Header.Get("Foo"); got != wantHeader {
   293  			t.Errorf("got %q, want %q", got, wantHeader)
   294  		}
   295  		if got := r.Header.Get(quotaProjectHeaderKey); got != testQuota {
   296  			t.Errorf("got %q, want %q", got, testQuota)
   297  		}
   298  	}))
   299  	defer ts.Close()
   300  	client, err := NewClient(&Options{
   301  		Headers: http.Header{"Foo": []string{wantHeader}},
   302  		InternalOptions: &InternalOptions{
   303  			DefaultEndpointTemplate: ts.URL,
   304  		},
   305  		DetectOpts: &credentials.DetectOptions{
   306  			Audience:         ts.URL,
   307  			CredentialsFile:  "../internal/testdata/sa.json",
   308  			UseSelfSignedJWT: true,
   309  		},
   310  	})
   311  	if err != nil {
   312  		t.Fatalf("NewClient() = %v", err)
   313  	}
   314  	req, err := http.NewRequest(http.MethodGet, ts.URL, nil)
   315  	if err != nil {
   316  		t.Fatal(err)
   317  	}
   318  	if _, err := client.Do(req); err != nil {
   319  		t.Fatalf("client.Get() = %v", err)
   320  	}
   321  }
   322  
   323  func TestNewClient_APIKey(t *testing.T) {
   324  	testQuota := "testquota"
   325  	apiKey := "thereisnospoon"
   326  	wantHeader := "bar"
   327  	t.Setenv(internal.QuotaProjectEnvVar, testQuota)
   328  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   329  		got := r.URL.Query().Get("key")
   330  		if got != apiKey {
   331  			t.Errorf("got %q, want %q", got, apiKey)
   332  		}
   333  		if got := r.Header.Get("Foo"); got != wantHeader {
   334  			t.Errorf("got %q, want %q", got, wantHeader)
   335  		}
   336  		if got := r.Header.Get(quotaProjectHeaderKey); got != testQuota {
   337  			t.Errorf("got %q, want %q", got, testQuota)
   338  		}
   339  	}))
   340  	defer ts.Close()
   341  	client, err := NewClient(&Options{
   342  		APIKey:  apiKey,
   343  		Headers: http.Header{"Foo": []string{wantHeader}},
   344  	})
   345  	if err != nil {
   346  		t.Fatalf("NewClient() = %v", err)
   347  	}
   348  	if _, err := client.Get(ts.URL); err != nil {
   349  		t.Fatalf("client.Get() = %v", err)
   350  	}
   351  }
   352  
   353  func TestNewClient_BaseRoundTripper(t *testing.T) {
   354  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   355  		got := r.Header.Get("Foo")
   356  		if want := "foo"; got != want {
   357  			t.Errorf("got %q, want %q", got, want)
   358  		}
   359  		got = r.Header.Get("Bar")
   360  		if want := "bar"; got != want {
   361  			t.Errorf("got %q, want %q", got, want)
   362  		}
   363  	}))
   364  	defer ts.Close()
   365  	client, err := NewClient(&Options{
   366  		BaseRoundTripper: &rt{key: "Bar", value: "bar"},
   367  		Headers:          http.Header{"Foo": []string{"foo"}},
   368  		APIKey:           "key",
   369  	})
   370  	if err != nil {
   371  		t.Fatalf("NewClient() = %v", err)
   372  	}
   373  	if _, err := client.Get(ts.URL); err != nil {
   374  		t.Fatalf("client.Get() = %v", err)
   375  	}
   376  }
   377  
   378  type staticTP string
   379  
   380  func (tp staticTP) Token(context.Context) (*auth.Token, error) {
   381  	return &auth.Token{
   382  		Value: string(tp),
   383  	}, nil
   384  }
   385  
   386  type rt struct {
   387  	key   string
   388  	value string
   389  }
   390  
   391  func (r *rt) RoundTrip(req *http.Request) (*http.Response, error) {
   392  	req2 := req.Clone(req.Context())
   393  	req2.Header.Add(r.key, r.value)
   394  	return http.DefaultTransport.RoundTrip(req2)
   395  }
   396  

View as plain text