...

Source file src/golang.org/x/oauth2/google/internal/externalaccountauthorizeduser/externalaccountauthorizeduser_test.go

Documentation: golang.org/x/oauth2/google/internal/externalaccountauthorizeduser

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package externalaccountauthorizeduser
     6  
     7  import (
     8  	"context"
     9  	"encoding/json"
    10  	"errors"
    11  	"io/ioutil"
    12  	"net/http"
    13  	"net/http/httptest"
    14  	"testing"
    15  	"time"
    16  
    17  	"golang.org/x/oauth2"
    18  	"golang.org/x/oauth2/google/internal/stsexchange"
    19  )
    20  
    21  const expiryDelta = 10 * time.Second
    22  
    23  var (
    24  	expiry    = time.Unix(234852, 0)
    25  	testNow   = func() time.Time { return expiry }
    26  	testValid = func(t oauth2.Token) bool {
    27  		return t.AccessToken != "" && !t.Expiry.Round(0).Add(-expiryDelta).Before(testNow())
    28  	}
    29  )
    30  
    31  type testRefreshTokenServer struct {
    32  	URL             string
    33  	Authorization   string
    34  	ContentType     string
    35  	Body            string
    36  	ResponsePayload *stsexchange.Response
    37  	Response        string
    38  	server          *httptest.Server
    39  }
    40  
    41  func TestExernalAccountAuthorizedUser_JustToken(t *testing.T) {
    42  	config := &Config{
    43  		Token:  "AAAAAAA",
    44  		Expiry: now().Add(time.Hour),
    45  	}
    46  	ts, err := config.TokenSource(context.Background())
    47  	if err != nil {
    48  		t.Fatalf("Error getting token source: %v", err)
    49  	}
    50  
    51  	token, err := ts.Token()
    52  	if err != nil {
    53  		t.Fatalf("Error retrieving Token: %v", err)
    54  	}
    55  	if got, want := token.AccessToken, "AAAAAAA"; got != want {
    56  		t.Fatalf("Unexpected access token, got %v, want %v", got, want)
    57  	}
    58  }
    59  
    60  func TestExernalAccountAuthorizedUser_TokenRefreshWithRefreshTokenInRespondse(t *testing.T) {
    61  	server := &testRefreshTokenServer{
    62  		URL:           "/",
    63  		Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
    64  		ContentType:   "application/x-www-form-urlencoded",
    65  		Body:          "grant_type=refresh_token&refresh_token=BBBBBBBBB",
    66  		ResponsePayload: &stsexchange.Response{
    67  			ExpiresIn:    3600,
    68  			AccessToken:  "AAAAAAA",
    69  			RefreshToken: "CCCCCCC",
    70  		},
    71  	}
    72  
    73  	url, err := server.run(t)
    74  	if err != nil {
    75  		t.Fatalf("Error starting server")
    76  	}
    77  	defer server.close(t)
    78  
    79  	config := &Config{
    80  		RefreshToken: "BBBBBBBBB",
    81  		TokenURL:     url,
    82  		ClientID:     "CLIENT_ID",
    83  		ClientSecret: "CLIENT_SECRET",
    84  	}
    85  	ts, err := config.TokenSource(context.Background())
    86  	if err != nil {
    87  		t.Fatalf("Error getting token source: %v", err)
    88  	}
    89  
    90  	token, err := ts.Token()
    91  	if err != nil {
    92  		t.Fatalf("Error retrieving Token: %v", err)
    93  	}
    94  	if got, want := token.AccessToken, "AAAAAAA"; got != want {
    95  		t.Fatalf("Unexpected access token, got %v, want %v", got, want)
    96  	}
    97  	if config.RefreshToken != "CCCCCCC" {
    98  		t.Fatalf("Refresh token not updated")
    99  	}
   100  }
   101  
   102  func TestExernalAccountAuthorizedUser_MinimumFieldsRequiredForRefresh(t *testing.T) {
   103  	server := &testRefreshTokenServer{
   104  		URL:           "/",
   105  		Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
   106  		ContentType:   "application/x-www-form-urlencoded",
   107  		Body:          "grant_type=refresh_token&refresh_token=BBBBBBBBB",
   108  		ResponsePayload: &stsexchange.Response{
   109  			ExpiresIn:   3600,
   110  			AccessToken: "AAAAAAA",
   111  		},
   112  	}
   113  
   114  	url, err := server.run(t)
   115  	if err != nil {
   116  		t.Fatalf("Error starting server")
   117  	}
   118  	defer server.close(t)
   119  
   120  	config := &Config{
   121  		RefreshToken: "BBBBBBBBB",
   122  		TokenURL:     url,
   123  		ClientID:     "CLIENT_ID",
   124  		ClientSecret: "CLIENT_SECRET",
   125  	}
   126  	ts, err := config.TokenSource(context.Background())
   127  	if err != nil {
   128  		t.Fatalf("Error getting token source: %v", err)
   129  	}
   130  
   131  	token, err := ts.Token()
   132  	if err != nil {
   133  		t.Fatalf("Error retrieving Token: %v", err)
   134  	}
   135  	if got, want := token.AccessToken, "AAAAAAA"; got != want {
   136  		t.Fatalf("Unexpected access token, got %v, want %v", got, want)
   137  	}
   138  }
   139  
   140  func TestExternalAccountAuthorizedUser_MissingRefreshFields(t *testing.T) {
   141  	server := &testRefreshTokenServer{
   142  		URL:           "/",
   143  		Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=",
   144  		ContentType:   "application/x-www-form-urlencoded",
   145  		Body:          "grant_type=refresh_token&refresh_token=BBBBBBBBB",
   146  		ResponsePayload: &stsexchange.Response{
   147  			ExpiresIn:   3600,
   148  			AccessToken: "AAAAAAA",
   149  		},
   150  	}
   151  
   152  	url, err := server.run(t)
   153  	if err != nil {
   154  		t.Fatalf("Error starting server")
   155  	}
   156  	defer server.close(t)
   157  	testCases := []struct {
   158  		name   string
   159  		config Config
   160  	}{
   161  		{
   162  			name:   "empty config",
   163  			config: Config{},
   164  		},
   165  		{
   166  			name: "missing refresh token",
   167  			config: Config{
   168  				TokenURL:     url,
   169  				ClientID:     "CLIENT_ID",
   170  				ClientSecret: "CLIENT_SECRET",
   171  			},
   172  		},
   173  		{
   174  			name: "missing token url",
   175  			config: Config{
   176  				RefreshToken: "BBBBBBBBB",
   177  				ClientID:     "CLIENT_ID",
   178  				ClientSecret: "CLIENT_SECRET",
   179  			},
   180  		},
   181  		{
   182  			name: "missing client id",
   183  			config: Config{
   184  				RefreshToken: "BBBBBBBBB",
   185  				TokenURL:     url,
   186  				ClientSecret: "CLIENT_SECRET",
   187  			},
   188  		},
   189  		{
   190  			name: "missing client secrect",
   191  			config: Config{
   192  				RefreshToken: "BBBBBBBBB",
   193  				TokenURL:     url,
   194  				ClientID:     "CLIENT_ID",
   195  			},
   196  		},
   197  	}
   198  	for _, tc := range testCases {
   199  		t.Run(tc.name, func(t *testing.T) {
   200  
   201  			expectErrMsg := "oauth2/google: Token should be created with fields to make it valid (`token` and `expiry`), or fields to allow it to refresh (`refresh_token`, `token_url`, `client_id`, `client_secret`)."
   202  			_, err := tc.config.TokenSource((context.Background()))
   203  			if err == nil {
   204  				t.Fatalf("Expected error, but received none")
   205  			}
   206  			if got := err.Error(); got != expectErrMsg {
   207  				t.Fatalf("Unexpected error, got %v, want %v", got, expectErrMsg)
   208  			}
   209  		})
   210  	}
   211  }
   212  
   213  func (trts *testRefreshTokenServer) run(t *testing.T) (string, error) {
   214  	t.Helper()
   215  	if trts.server != nil {
   216  		return "", errors.New("Server is already running")
   217  	}
   218  	trts.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   219  		if got, want := r.URL.String(), trts.URL; got != want {
   220  			t.Errorf("URL.String(): got %v but want %v", got, want)
   221  		}
   222  		headerAuth := r.Header.Get("Authorization")
   223  		if got, want := headerAuth, trts.Authorization; got != want {
   224  			t.Errorf("got %v but want %v", got, want)
   225  		}
   226  		headerContentType := r.Header.Get("Content-Type")
   227  		if got, want := headerContentType, trts.ContentType; got != want {
   228  			t.Errorf("got %v but want %v", got, want)
   229  		}
   230  		body, err := ioutil.ReadAll(r.Body)
   231  		if err != nil {
   232  			t.Fatalf("Failed reading request body: %s.", err)
   233  		}
   234  		if got, want := string(body), trts.Body; got != want {
   235  			t.Errorf("Unexpected exchange payload: got %v but want %v", got, want)
   236  		}
   237  		w.Header().Set("Content-Type", "application/json")
   238  		if trts.ResponsePayload != nil {
   239  			content, err := json.Marshal(trts.ResponsePayload)
   240  			if err != nil {
   241  				t.Fatalf("unable to marshall response JSON")
   242  			}
   243  			w.Write(content)
   244  		} else {
   245  			w.Write([]byte(trts.Response))
   246  		}
   247  	}))
   248  	return trts.server.URL, nil
   249  }
   250  
   251  func (trts *testRefreshTokenServer) close(t *testing.T) error {
   252  	t.Helper()
   253  	if trts.server == nil {
   254  		return errors.New("No server is running")
   255  	}
   256  	trts.server.Close()
   257  	trts.server = nil
   258  	return nil
   259  }
   260  

View as plain text