// Copyright 2023 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package externalaccountauthorizeduser import ( "context" "encoding/json" "errors" "io/ioutil" "net/http" "net/http/httptest" "testing" "time" "golang.org/x/oauth2" "golang.org/x/oauth2/google/internal/stsexchange" ) const expiryDelta = 10 * time.Second var ( expiry = time.Unix(234852, 0) testNow = func() time.Time { return expiry } testValid = func(t oauth2.Token) bool { return t.AccessToken != "" && !t.Expiry.Round(0).Add(-expiryDelta).Before(testNow()) } ) type testRefreshTokenServer struct { URL string Authorization string ContentType string Body string ResponsePayload *stsexchange.Response Response string server *httptest.Server } func TestExernalAccountAuthorizedUser_JustToken(t *testing.T) { config := &Config{ Token: "AAAAAAA", Expiry: now().Add(time.Hour), } ts, err := config.TokenSource(context.Background()) if err != nil { t.Fatalf("Error getting token source: %v", err) } token, err := ts.Token() if err != nil { t.Fatalf("Error retrieving Token: %v", err) } if got, want := token.AccessToken, "AAAAAAA"; got != want { t.Fatalf("Unexpected access token, got %v, want %v", got, want) } } func TestExernalAccountAuthorizedUser_TokenRefreshWithRefreshTokenInRespondse(t *testing.T) { server := &testRefreshTokenServer{ URL: "/", Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=", ContentType: "application/x-www-form-urlencoded", Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB", ResponsePayload: &stsexchange.Response{ ExpiresIn: 3600, AccessToken: "AAAAAAA", RefreshToken: "CCCCCCC", }, } url, err := server.run(t) if err != nil { t.Fatalf("Error starting server") } defer server.close(t) config := &Config{ RefreshToken: "BBBBBBBBB", TokenURL: url, ClientID: "CLIENT_ID", ClientSecret: "CLIENT_SECRET", } ts, err := config.TokenSource(context.Background()) if err != nil { t.Fatalf("Error getting token source: %v", err) } token, err := ts.Token() if err != nil { t.Fatalf("Error retrieving Token: %v", err) } if got, want := token.AccessToken, "AAAAAAA"; got != want { t.Fatalf("Unexpected access token, got %v, want %v", got, want) } if config.RefreshToken != "CCCCCCC" { t.Fatalf("Refresh token not updated") } } func TestExernalAccountAuthorizedUser_MinimumFieldsRequiredForRefresh(t *testing.T) { server := &testRefreshTokenServer{ URL: "/", Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=", ContentType: "application/x-www-form-urlencoded", Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB", ResponsePayload: &stsexchange.Response{ ExpiresIn: 3600, AccessToken: "AAAAAAA", }, } url, err := server.run(t) if err != nil { t.Fatalf("Error starting server") } defer server.close(t) config := &Config{ RefreshToken: "BBBBBBBBB", TokenURL: url, ClientID: "CLIENT_ID", ClientSecret: "CLIENT_SECRET", } ts, err := config.TokenSource(context.Background()) if err != nil { t.Fatalf("Error getting token source: %v", err) } token, err := ts.Token() if err != nil { t.Fatalf("Error retrieving Token: %v", err) } if got, want := token.AccessToken, "AAAAAAA"; got != want { t.Fatalf("Unexpected access token, got %v, want %v", got, want) } } func TestExternalAccountAuthorizedUser_MissingRefreshFields(t *testing.T) { server := &testRefreshTokenServer{ URL: "/", Authorization: "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=", ContentType: "application/x-www-form-urlencoded", Body: "grant_type=refresh_token&refresh_token=BBBBBBBBB", ResponsePayload: &stsexchange.Response{ ExpiresIn: 3600, AccessToken: "AAAAAAA", }, } url, err := server.run(t) if err != nil { t.Fatalf("Error starting server") } defer server.close(t) testCases := []struct { name string config Config }{ { name: "empty config", config: Config{}, }, { name: "missing refresh token", config: Config{ TokenURL: url, ClientID: "CLIENT_ID", ClientSecret: "CLIENT_SECRET", }, }, { name: "missing token url", config: Config{ RefreshToken: "BBBBBBBBB", ClientID: "CLIENT_ID", ClientSecret: "CLIENT_SECRET", }, }, { name: "missing client id", config: Config{ RefreshToken: "BBBBBBBBB", TokenURL: url, ClientSecret: "CLIENT_SECRET", }, }, { name: "missing client secrect", config: Config{ RefreshToken: "BBBBBBBBB", TokenURL: url, ClientID: "CLIENT_ID", }, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { expectErrMsg := "oauth2/google: Token should be created with fields to make it valid (`token` and `expiry`), or fields to allow it to refresh (`refresh_token`, `token_url`, `client_id`, `client_secret`)." _, err := tc.config.TokenSource((context.Background())) if err == nil { t.Fatalf("Expected error, but received none") } if got := err.Error(); got != expectErrMsg { t.Fatalf("Unexpected error, got %v, want %v", got, expectErrMsg) } }) } } func (trts *testRefreshTokenServer) run(t *testing.T) (string, error) { t.Helper() if trts.server != nil { return "", errors.New("Server is already running") } trts.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if got, want := r.URL.String(), trts.URL; got != want { t.Errorf("URL.String(): got %v but want %v", got, want) } headerAuth := r.Header.Get("Authorization") if got, want := headerAuth, trts.Authorization; got != want { t.Errorf("got %v but want %v", got, want) } headerContentType := r.Header.Get("Content-Type") if got, want := headerContentType, trts.ContentType; got != want { t.Errorf("got %v but want %v", got, want) } body, err := ioutil.ReadAll(r.Body) if err != nil { t.Fatalf("Failed reading request body: %s.", err) } if got, want := string(body), trts.Body; got != want { t.Errorf("Unexpected exchange payload: got %v but want %v", got, want) } w.Header().Set("Content-Type", "application/json") if trts.ResponsePayload != nil { content, err := json.Marshal(trts.ResponsePayload) if err != nil { t.Fatalf("unable to marshall response JSON") } w.Write(content) } else { w.Write([]byte(trts.Response)) } })) return trts.server.URL, nil } func (trts *testRefreshTokenServer) close(t *testing.T) error { t.Helper() if trts.server == nil { return errors.New("No server is running") } trts.server.Close() trts.server = nil return nil }