...

Source file src/cloud.google.com/go/auth/credentials/internal/externalaccountuser/externalaccountuser_test.go

Documentation: cloud.google.com/go/auth/credentials/internal/externalaccountuser

     1  // Copyright 2023 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package externalaccountuser
    16  
    17  import (
    18  	"context"
    19  	"encoding/json"
    20  	"io"
    21  	"net/http"
    22  	"net/http/httptest"
    23  	"testing"
    24  
    25  	"cloud.google.com/go/auth/credentials/internal/stsexchange"
    26  	"cloud.google.com/go/auth/internal"
    27  )
    28  
    29  type testTokenServer struct {
    30  	URL             string
    31  	Authorization   string
    32  	ContentType     string
    33  	Body            string
    34  	ResponsePayload *stsexchange.TokenResponse
    35  	Response        string
    36  	server          *httptest.Server
    37  }
    38  
    39  func TestExernalAccountAuthorizedUser_TokenRefreshWithRefreshTokenInResponse(t *testing.T) {
    40  	s := &testTokenServer{
    41  		URL:           "/",
    42  		Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
    43  		ContentType:   "application/x-www-form-urlencoded",
    44  		Body:          "grant_type=refresh_token&refresh_token=BBBBBBBBB",
    45  		ResponsePayload: &stsexchange.TokenResponse{
    46  			ExpiresIn:    3600,
    47  			AccessToken:  "AAAAAAA",
    48  			RefreshToken: "CCCCCCC",
    49  		},
    50  	}
    51  
    52  	s.startTestServer(t)
    53  	defer s.server.Close()
    54  
    55  	opts := &Options{
    56  		RefreshToken: "BBBBBBBBB",
    57  		TokenURL:     s.server.URL,
    58  		ClientID:     "CLIENT_ID",
    59  		ClientSecret: "CLIENT_SECRET",
    60  		Client:       internal.CloneDefaultClient(),
    61  	}
    62  	tp, err := NewTokenProvider(opts)
    63  	if err != nil {
    64  		t.Fatalf("NewTokenProvider() =  %v", err)
    65  	}
    66  
    67  	token, err := tp.Token(context.Background())
    68  	if err != nil {
    69  		t.Fatalf("Token() = %v", err)
    70  	}
    71  	if got, want := token.Value, "AAAAAAA"; got != want {
    72  		t.Fatalf("got %v, want %v", got, want)
    73  	}
    74  	if got, want := opts.RefreshToken, "CCCCCCC"; got != want {
    75  		t.Fatalf("got %v, want %v", got, want)
    76  	}
    77  }
    78  
    79  func TestExernalAccountAuthorizedUser_MinimumFieldsRequiredForRefresh(t *testing.T) {
    80  	s := &testTokenServer{
    81  		URL:           "/",
    82  		Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
    83  		ContentType:   "application/x-www-form-urlencoded",
    84  		Body:          "grant_type=refresh_token&refresh_token=BBBBBBBBB",
    85  		ResponsePayload: &stsexchange.TokenResponse{
    86  			ExpiresIn:   3600,
    87  			AccessToken: "AAAAAAA",
    88  		},
    89  	}
    90  
    91  	s.startTestServer(t)
    92  	defer s.server.Close()
    93  
    94  	opts := &Options{
    95  		RefreshToken: "BBBBBBBBB",
    96  		TokenURL:     s.server.URL,
    97  		ClientID:     "CLIENT_ID",
    98  		ClientSecret: "CLIENT_SECRET",
    99  		Client:       internal.CloneDefaultClient(),
   100  	}
   101  	ts, err := NewTokenProvider(opts)
   102  	if err != nil {
   103  		t.Fatalf("NewTokenProvider() = %v", err)
   104  	}
   105  
   106  	token, err := ts.Token(context.Background())
   107  	if err != nil {
   108  		t.Fatalf("Token() = %v", err)
   109  	}
   110  	if got, want := token.Value, "AAAAAAA"; got != want {
   111  		t.Fatalf("got %v, want %v", got, want)
   112  	}
   113  }
   114  
   115  func TestExternalAccountAuthorizedUser_MissingRefreshFields(t *testing.T) {
   116  	s := &testTokenServer{
   117  		URL:           "/",
   118  		Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
   119  		ContentType:   "application/x-www-form-urlencoded",
   120  		Body:          "grant_type=refresh_token&refresh_token=BBBBBBBBB",
   121  		ResponsePayload: &stsexchange.TokenResponse{
   122  			ExpiresIn:   3600,
   123  			AccessToken: "AAAAAAA",
   124  		},
   125  	}
   126  
   127  	s.startTestServer(t)
   128  	defer s.server.Close()
   129  	testCases := []struct {
   130  		name string
   131  		opts *Options
   132  	}{
   133  		{
   134  			name: "empty config",
   135  			opts: &Options{},
   136  		},
   137  		{
   138  			name: "missing refresh token",
   139  			opts: &Options{
   140  				TokenURL:     s.server.URL,
   141  				ClientID:     "CLIENT_ID",
   142  				ClientSecret: "CLIENT_SECRET",
   143  			},
   144  		},
   145  		{
   146  			name: "missing token url",
   147  			opts: &Options{
   148  				RefreshToken: "BBBBBBBBB",
   149  				ClientID:     "CLIENT_ID",
   150  				ClientSecret: "CLIENT_SECRET",
   151  			},
   152  		},
   153  		{
   154  			name: "missing client id",
   155  			opts: &Options{
   156  				RefreshToken: "BBBBBBBBB",
   157  				TokenURL:     s.server.URL,
   158  				ClientSecret: "CLIENT_SECRET",
   159  			},
   160  		},
   161  		{
   162  			name: "missing client secrect",
   163  			opts: &Options{
   164  				RefreshToken: "BBBBBBBBB",
   165  				TokenURL:     s.server.URL,
   166  				ClientID:     "CLIENT_ID",
   167  			},
   168  		},
   169  	}
   170  	for _, tt := range testCases {
   171  		t.Run(tt.name, func(t *testing.T) {
   172  			if _, err := NewTokenProvider(tt.opts); err == nil {
   173  				t.Fatalf("got nil, want an error")
   174  			}
   175  		})
   176  	}
   177  }
   178  
   179  func (s *testTokenServer) startTestServer(t *testing.T) {
   180  	t.Helper()
   181  	s.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   182  		if got, want := r.URL.String(), s.URL; got != want {
   183  			t.Errorf("got %v, want %v", got, want)
   184  		}
   185  		headerAuth := r.Header.Get("Authorization")
   186  		if got, want := headerAuth, s.Authorization; got != want {
   187  			t.Errorf("got %v, want %v", got, want)
   188  		}
   189  		headerContentType := r.Header.Get("Content-Type")
   190  		if got, want := headerContentType, s.ContentType; got != want {
   191  			t.Errorf("got %v. want %v", got, want)
   192  		}
   193  		body, err := io.ReadAll(r.Body)
   194  		if err != nil {
   195  			t.Error(err)
   196  		}
   197  		if got, want := string(body), s.Body; got != want {
   198  			t.Errorf("got %q, want %q", got, want)
   199  		}
   200  		w.Header().Set("Content-Type", "application/json")
   201  		if s.ResponsePayload != nil {
   202  			content, err := json.Marshal(s.ResponsePayload)
   203  			if err != nil {
   204  				t.Error(err)
   205  			}
   206  			w.Write(content)
   207  		} else {
   208  			w.Write([]byte(s.Response))
   209  		}
   210  	}))
   211  }
   212  

View as plain text