...

Source file src/golang.org/x/oauth2/oauth2_test.go

Documentation: golang.org/x/oauth2

     1  // Copyright 2014 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package oauth2
     6  
     7  import (
     8  	"context"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"io/ioutil"
    13  	"net/http"
    14  	"net/http/httptest"
    15  	"net/url"
    16  	"testing"
    17  	"time"
    18  )
    19  
    20  type mockTransport struct {
    21  	rt func(req *http.Request) (resp *http.Response, err error)
    22  }
    23  
    24  func (t *mockTransport) RoundTrip(req *http.Request) (resp *http.Response, err error) {
    25  	return t.rt(req)
    26  }
    27  
    28  func newConf(url string) *Config {
    29  	return &Config{
    30  		ClientID:     "CLIENT_ID",
    31  		ClientSecret: "CLIENT_SECRET",
    32  		RedirectURL:  "REDIRECT_URL",
    33  		Scopes:       []string{"scope1", "scope2"},
    34  		Endpoint: Endpoint{
    35  			AuthURL:  url + "/auth",
    36  			TokenURL: url + "/token",
    37  		},
    38  	}
    39  }
    40  
    41  func TestAuthCodeURL(t *testing.T) {
    42  	conf := newConf("server")
    43  	url := conf.AuthCodeURL("foo", AccessTypeOffline, ApprovalForce)
    44  	const want = "server/auth?access_type=offline&client_id=CLIENT_ID&prompt=consent&redirect_uri=REDIRECT_URL&response_type=code&scope=scope1+scope2&state=foo"
    45  	if got := url; got != want {
    46  		t.Errorf("got auth code URL = %q; want %q", got, want)
    47  	}
    48  }
    49  
    50  func TestAuthCodeURL_CustomParam(t *testing.T) {
    51  	conf := newConf("server")
    52  	param := SetAuthURLParam("foo", "bar")
    53  	url := conf.AuthCodeURL("baz", param)
    54  	const want = "server/auth?client_id=CLIENT_ID&foo=bar&redirect_uri=REDIRECT_URL&response_type=code&scope=scope1+scope2&state=baz"
    55  	if got := url; got != want {
    56  		t.Errorf("got auth code = %q; want %q", got, want)
    57  	}
    58  }
    59  
    60  func TestAuthCodeURL_Optional(t *testing.T) {
    61  	conf := &Config{
    62  		ClientID: "CLIENT_ID",
    63  		Endpoint: Endpoint{
    64  			AuthURL:  "/auth-url",
    65  			TokenURL: "/token-url",
    66  		},
    67  	}
    68  	url := conf.AuthCodeURL("")
    69  	const want = "/auth-url?client_id=CLIENT_ID&response_type=code"
    70  	if got := url; got != want {
    71  		t.Fatalf("got auth code = %q; want %q", got, want)
    72  	}
    73  }
    74  
    75  func TestURLUnsafeClientConfig(t *testing.T) {
    76  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    77  		if got, want := r.Header.Get("Authorization"), "Basic Q0xJRU5UX0lEJTNGJTNGOkNMSUVOVF9TRUNSRVQlM0YlM0Y="; got != want {
    78  			t.Errorf("Authorization header = %q; want %q", got, want)
    79  		}
    80  
    81  		w.Header().Set("Content-Type", "application/x-www-form-urlencoded")
    82  		w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer"))
    83  	}))
    84  	defer ts.Close()
    85  	conf := newConf(ts.URL)
    86  	conf.ClientID = "CLIENT_ID??"
    87  	conf.ClientSecret = "CLIENT_SECRET??"
    88  	_, err := conf.Exchange(context.Background(), "exchange-code")
    89  	if err != nil {
    90  		t.Error(err)
    91  	}
    92  }
    93  
    94  func TestExchangeRequest(t *testing.T) {
    95  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    96  		if r.URL.String() != "/token" {
    97  			t.Errorf("Unexpected exchange request URL %q", r.URL)
    98  		}
    99  		headerAuth := r.Header.Get("Authorization")
   100  		if want := "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ="; headerAuth != want {
   101  			t.Errorf("Unexpected authorization header %q, want %q", headerAuth, want)
   102  		}
   103  		headerContentType := r.Header.Get("Content-Type")
   104  		if headerContentType != "application/x-www-form-urlencoded" {
   105  			t.Errorf("Unexpected Content-Type header %q", headerContentType)
   106  		}
   107  		body, err := ioutil.ReadAll(r.Body)
   108  		if err != nil {
   109  			t.Errorf("Failed reading request body: %s.", err)
   110  		}
   111  		if string(body) != "code=exchange-code&grant_type=authorization_code&redirect_uri=REDIRECT_URL" {
   112  			t.Errorf("Unexpected exchange payload; got %q", body)
   113  		}
   114  		w.Header().Set("Content-Type", "application/x-www-form-urlencoded")
   115  		w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer"))
   116  	}))
   117  	defer ts.Close()
   118  	conf := newConf(ts.URL)
   119  	tok, err := conf.Exchange(context.Background(), "exchange-code")
   120  	if err != nil {
   121  		t.Error(err)
   122  	}
   123  	if !tok.Valid() {
   124  		t.Fatalf("Token invalid. Got: %#v", tok)
   125  	}
   126  	if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" {
   127  		t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
   128  	}
   129  	if tok.TokenType != "bearer" {
   130  		t.Errorf("Unexpected token type, %#v.", tok.TokenType)
   131  	}
   132  	scope := tok.Extra("scope")
   133  	if scope != "user" {
   134  		t.Errorf("Unexpected value for scope: %v", scope)
   135  	}
   136  }
   137  
   138  func TestExchangeRequest_CustomParam(t *testing.T) {
   139  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   140  		if r.URL.String() != "/token" {
   141  			t.Errorf("Unexpected exchange request URL, %v is found.", r.URL)
   142  		}
   143  		headerAuth := r.Header.Get("Authorization")
   144  		if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" {
   145  			t.Errorf("Unexpected authorization header, %v is found.", headerAuth)
   146  		}
   147  		headerContentType := r.Header.Get("Content-Type")
   148  		if headerContentType != "application/x-www-form-urlencoded" {
   149  			t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
   150  		}
   151  		body, err := ioutil.ReadAll(r.Body)
   152  		if err != nil {
   153  			t.Errorf("Failed reading request body: %s.", err)
   154  		}
   155  		if string(body) != "code=exchange-code&foo=bar&grant_type=authorization_code&redirect_uri=REDIRECT_URL" {
   156  			t.Errorf("Unexpected exchange payload, %v is found.", string(body))
   157  		}
   158  		w.Header().Set("Content-Type", "application/x-www-form-urlencoded")
   159  		w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer"))
   160  	}))
   161  	defer ts.Close()
   162  	conf := newConf(ts.URL)
   163  
   164  	param := SetAuthURLParam("foo", "bar")
   165  	tok, err := conf.Exchange(context.Background(), "exchange-code", param)
   166  	if err != nil {
   167  		t.Error(err)
   168  	}
   169  	if !tok.Valid() {
   170  		t.Fatalf("Token invalid. Got: %#v", tok)
   171  	}
   172  	if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" {
   173  		t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
   174  	}
   175  	if tok.TokenType != "bearer" {
   176  		t.Errorf("Unexpected token type, %#v.", tok.TokenType)
   177  	}
   178  	scope := tok.Extra("scope")
   179  	if scope != "user" {
   180  		t.Errorf("Unexpected value for scope: %v", scope)
   181  	}
   182  }
   183  
   184  func TestExchangeRequest_JSONResponse(t *testing.T) {
   185  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   186  		if r.URL.String() != "/token" {
   187  			t.Errorf("Unexpected exchange request URL, %v is found.", r.URL)
   188  		}
   189  		headerAuth := r.Header.Get("Authorization")
   190  		if headerAuth != "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ=" {
   191  			t.Errorf("Unexpected authorization header, %v is found.", headerAuth)
   192  		}
   193  		headerContentType := r.Header.Get("Content-Type")
   194  		if headerContentType != "application/x-www-form-urlencoded" {
   195  			t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
   196  		}
   197  		body, err := ioutil.ReadAll(r.Body)
   198  		if err != nil {
   199  			t.Errorf("Failed reading request body: %s.", err)
   200  		}
   201  		if string(body) != "code=exchange-code&grant_type=authorization_code&redirect_uri=REDIRECT_URL" {
   202  			t.Errorf("Unexpected exchange payload, %v is found.", string(body))
   203  		}
   204  		w.Header().Set("Content-Type", "application/json")
   205  		w.Write([]byte(`{"access_token": "90d64460d14870c08c81352a05dedd3465940a7c", "scope": "user", "token_type": "bearer", "expires_in": 86400}`))
   206  	}))
   207  	defer ts.Close()
   208  	conf := newConf(ts.URL)
   209  	tok, err := conf.Exchange(context.Background(), "exchange-code")
   210  	if err != nil {
   211  		t.Error(err)
   212  	}
   213  	if !tok.Valid() {
   214  		t.Fatalf("Token invalid. Got: %#v", tok)
   215  	}
   216  	if tok.AccessToken != "90d64460d14870c08c81352a05dedd3465940a7c" {
   217  		t.Errorf("Unexpected access token, %#v.", tok.AccessToken)
   218  	}
   219  	if tok.TokenType != "bearer" {
   220  		t.Errorf("Unexpected token type, %#v.", tok.TokenType)
   221  	}
   222  	scope := tok.Extra("scope")
   223  	if scope != "user" {
   224  		t.Errorf("Unexpected value for scope: %v", scope)
   225  	}
   226  	expiresIn := tok.Extra("expires_in")
   227  	if expiresIn != float64(86400) {
   228  		t.Errorf("Unexpected non-numeric value for expires_in: %v", expiresIn)
   229  	}
   230  }
   231  
   232  func TestExtraValueRetrieval(t *testing.T) {
   233  	values := url.Values{}
   234  	kvmap := map[string]string{
   235  		"scope": "user", "token_type": "bearer", "expires_in": "86400.92",
   236  		"server_time": "1443571905.5606415", "referer_ip": "10.0.0.1",
   237  		"etag": "\"afZYj912P4alikMz_P11982\"", "request_id": "86400",
   238  		"untrimmed": "  untrimmed  ",
   239  	}
   240  	for key, value := range kvmap {
   241  		values.Set(key, value)
   242  	}
   243  
   244  	tok := Token{raw: values}
   245  	scope := tok.Extra("scope")
   246  	if got, want := scope, "user"; got != want {
   247  		t.Errorf("got scope = %q; want %q", got, want)
   248  	}
   249  	serverTime := tok.Extra("server_time")
   250  	if got, want := serverTime, 1443571905.5606415; got != want {
   251  		t.Errorf("got server_time value = %v; want %v", got, want)
   252  	}
   253  	refererIP := tok.Extra("referer_ip")
   254  	if got, want := refererIP, "10.0.0.1"; got != want {
   255  		t.Errorf("got referer_ip value = %v, want %v", got, want)
   256  	}
   257  	expiresIn := tok.Extra("expires_in")
   258  	if got, want := expiresIn, 86400.92; got != want {
   259  		t.Errorf("got expires_in value = %v, want %v", got, want)
   260  	}
   261  	requestID := tok.Extra("request_id")
   262  	if got, want := requestID, int64(86400); got != want {
   263  		t.Errorf("got request_id value = %v, want %v", got, want)
   264  	}
   265  	untrimmed := tok.Extra("untrimmed")
   266  	if got, want := untrimmed, "  untrimmed  "; got != want {
   267  		t.Errorf("got untrimmed = %q; want %q", got, want)
   268  	}
   269  }
   270  
   271  const day = 24 * time.Hour
   272  
   273  func TestExchangeRequest_JSONResponse_Expiry(t *testing.T) {
   274  	seconds := int32(day.Seconds())
   275  	for _, c := range []struct {
   276  		name        string
   277  		expires     string
   278  		want        bool
   279  		nullExpires bool
   280  	}{
   281  		{"normal", fmt.Sprintf(`"expires_in": %d`, seconds), true, false},
   282  		{"paypal", fmt.Sprintf(`"expires_in": "%d"`, seconds), true, false},
   283  		{"issue_239", fmt.Sprintf(`"expires_in": null`), true, true},
   284  
   285  		{"wrong_type", `"expires_in": false`, false, false},
   286  		{"wrong_type2", `"expires_in": {}`, false, false},
   287  		{"wrong_value", `"expires_in": "zzz"`, false, false},
   288  	} {
   289  		t.Run(c.name, func(t *testing.T) {
   290  			testExchangeRequest_JSONResponse_expiry(t, c.expires, c.want, c.nullExpires)
   291  		})
   292  	}
   293  }
   294  
   295  func testExchangeRequest_JSONResponse_expiry(t *testing.T, exp string, want, nullExpires bool) {
   296  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   297  		w.Header().Set("Content-Type", "application/json")
   298  		w.Write([]byte(fmt.Sprintf(`{"access_token": "90d", "scope": "user", "token_type": "bearer", %s}`, exp)))
   299  	}))
   300  	defer ts.Close()
   301  	conf := newConf(ts.URL)
   302  	t1 := time.Now().Add(day)
   303  	tok, err := conf.Exchange(context.Background(), "exchange-code")
   304  	t2 := t1.Add(day)
   305  
   306  	if got := (err == nil); got != want {
   307  		if want {
   308  			t.Errorf("unexpected error: got %v", err)
   309  		} else {
   310  			t.Errorf("unexpected success")
   311  		}
   312  	}
   313  	if !want {
   314  		return
   315  	}
   316  	if !tok.Valid() {
   317  		t.Fatalf("Token invalid. Got: %#v", tok)
   318  	}
   319  	expiry := tok.Expiry
   320  
   321  	if nullExpires && expiry.IsZero() {
   322  		return
   323  	}
   324  	if expiry.Before(t1) || expiry.After(t2) {
   325  		t.Errorf("Unexpected value for Expiry: %v (should be between %v and %v)", expiry, t1, t2)
   326  	}
   327  }
   328  
   329  func TestExchangeRequest_BadResponse(t *testing.T) {
   330  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   331  		w.Header().Set("Content-Type", "application/json")
   332  		w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`))
   333  	}))
   334  	defer ts.Close()
   335  	conf := newConf(ts.URL)
   336  	_, err := conf.Exchange(context.Background(), "code")
   337  	if err == nil {
   338  		t.Error("expected error from missing access_token")
   339  	}
   340  }
   341  
   342  func TestExchangeRequest_BadResponseType(t *testing.T) {
   343  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   344  		w.Header().Set("Content-Type", "application/json")
   345  		w.Write([]byte(`{"access_token":123,  "scope": "user", "token_type": "bearer"}`))
   346  	}))
   347  	defer ts.Close()
   348  	conf := newConf(ts.URL)
   349  	_, err := conf.Exchange(context.Background(), "exchange-code")
   350  	if err == nil {
   351  		t.Error("expected error from non-string access_token")
   352  	}
   353  }
   354  
   355  func TestExchangeRequest_NonBasicAuth(t *testing.T) {
   356  	tr := &mockTransport{
   357  		rt: func(r *http.Request) (w *http.Response, err error) {
   358  			headerAuth := r.Header.Get("Authorization")
   359  			if headerAuth != "" {
   360  				t.Errorf("Unexpected authorization header %q", headerAuth)
   361  			}
   362  			return nil, errors.New("no response")
   363  		},
   364  	}
   365  	c := &http.Client{Transport: tr}
   366  	conf := &Config{
   367  		ClientID: "CLIENT_ID",
   368  		Endpoint: Endpoint{
   369  			AuthURL:   "https://accounts.google.com/auth",
   370  			TokenURL:  "https://accounts.google.com/token",
   371  			AuthStyle: AuthStyleInParams,
   372  		},
   373  	}
   374  
   375  	ctx := context.WithValue(context.Background(), HTTPClient, c)
   376  	conf.Exchange(ctx, "code")
   377  }
   378  
   379  func TestPasswordCredentialsTokenRequest(t *testing.T) {
   380  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   381  		defer r.Body.Close()
   382  		expected := "/token"
   383  		if r.URL.String() != expected {
   384  			t.Errorf("URL = %q; want %q", r.URL, expected)
   385  		}
   386  		headerAuth := r.Header.Get("Authorization")
   387  		expected = "Basic Q0xJRU5UX0lEOkNMSUVOVF9TRUNSRVQ="
   388  		if headerAuth != expected {
   389  			t.Errorf("Authorization header = %q; want %q", headerAuth, expected)
   390  		}
   391  		headerContentType := r.Header.Get("Content-Type")
   392  		expected = "application/x-www-form-urlencoded"
   393  		if headerContentType != expected {
   394  			t.Errorf("Content-Type header = %q; want %q", headerContentType, expected)
   395  		}
   396  		body, err := ioutil.ReadAll(r.Body)
   397  		if err != nil {
   398  			t.Errorf("Failed reading request body: %s.", err)
   399  		}
   400  		expected = "grant_type=password&password=password1&scope=scope1+scope2&username=user1"
   401  		if string(body) != expected {
   402  			t.Errorf("res.Body = %q; want %q", string(body), expected)
   403  		}
   404  		w.Header().Set("Content-Type", "application/x-www-form-urlencoded")
   405  		w.Write([]byte("access_token=90d64460d14870c08c81352a05dedd3465940a7c&scope=user&token_type=bearer"))
   406  	}))
   407  	defer ts.Close()
   408  	conf := newConf(ts.URL)
   409  	tok, err := conf.PasswordCredentialsToken(context.Background(), "user1", "password1")
   410  	if err != nil {
   411  		t.Error(err)
   412  	}
   413  	if !tok.Valid() {
   414  		t.Fatalf("Token invalid. Got: %#v", tok)
   415  	}
   416  	expected := "90d64460d14870c08c81352a05dedd3465940a7c"
   417  	if tok.AccessToken != expected {
   418  		t.Errorf("AccessToken = %q; want %q", tok.AccessToken, expected)
   419  	}
   420  	expected = "bearer"
   421  	if tok.TokenType != expected {
   422  		t.Errorf("TokenType = %q; want %q", tok.TokenType, expected)
   423  	}
   424  }
   425  
   426  func TestTokenRefreshRequest(t *testing.T) {
   427  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   428  		if r.URL.String() == "/somethingelse" {
   429  			return
   430  		}
   431  		if r.URL.String() != "/token" {
   432  			t.Errorf("Unexpected token refresh request URL %q", r.URL)
   433  		}
   434  		headerContentType := r.Header.Get("Content-Type")
   435  		if headerContentType != "application/x-www-form-urlencoded" {
   436  			t.Errorf("Unexpected Content-Type header %q", headerContentType)
   437  		}
   438  		body, _ := ioutil.ReadAll(r.Body)
   439  		if string(body) != "grant_type=refresh_token&refresh_token=REFRESH_TOKEN" {
   440  			t.Errorf("Unexpected refresh token payload %q", body)
   441  		}
   442  		w.Header().Set("Content-Type", "application/json")
   443  		io.WriteString(w, `{"access_token": "foo", "refresh_token": "bar"}`)
   444  	}))
   445  	defer ts.Close()
   446  	conf := newConf(ts.URL)
   447  	c := conf.Client(context.Background(), &Token{RefreshToken: "REFRESH_TOKEN"})
   448  	c.Get(ts.URL + "/somethingelse")
   449  }
   450  
   451  func TestFetchWithNoRefreshToken(t *testing.T) {
   452  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   453  		if r.URL.String() == "/somethingelse" {
   454  			return
   455  		}
   456  		if r.URL.String() != "/token" {
   457  			t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL)
   458  		}
   459  		headerContentType := r.Header.Get("Content-Type")
   460  		if headerContentType != "application/x-www-form-urlencoded" {
   461  			t.Errorf("Unexpected Content-Type header, %v is found.", headerContentType)
   462  		}
   463  		body, _ := ioutil.ReadAll(r.Body)
   464  		if string(body) != "client_id=CLIENT_ID&grant_type=refresh_token&refresh_token=REFRESH_TOKEN" {
   465  			t.Errorf("Unexpected refresh token payload, %v is found.", string(body))
   466  		}
   467  	}))
   468  	defer ts.Close()
   469  	conf := newConf(ts.URL)
   470  	c := conf.Client(context.Background(), nil)
   471  	_, err := c.Get(ts.URL + "/somethingelse")
   472  	if err == nil {
   473  		t.Errorf("Fetch should return an error if no refresh token is set")
   474  	}
   475  }
   476  
   477  func TestTokenRetrieveError(t *testing.T) {
   478  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   479  		if r.URL.String() != "/token" {
   480  			t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL)
   481  		}
   482  		w.Header().Set("Content-type", "application/json")
   483  		// "The authorization server responds with an HTTP 400 (Bad Request)" https://www.rfc-editor.org/rfc/rfc6749#section-5.2
   484  		w.WriteHeader(http.StatusBadRequest)
   485  		w.Write([]byte(`{"error": "invalid_grant"}`))
   486  	}))
   487  	defer ts.Close()
   488  	conf := newConf(ts.URL)
   489  	_, err := conf.Exchange(context.Background(), "exchange-code")
   490  	if err == nil {
   491  		t.Fatalf("got no error, expected one")
   492  	}
   493  	re, ok := err.(*RetrieveError)
   494  	if !ok {
   495  		t.Fatalf("got %T error, expected *RetrieveError; error was: %v", err, err)
   496  	}
   497  	expected := `oauth2: "invalid_grant"`
   498  	if errStr := err.Error(); errStr != expected {
   499  		t.Fatalf("got %#v, expected %#v", errStr, expected)
   500  	}
   501  	expected = "invalid_grant"
   502  	if re.ErrorCode != expected {
   503  		t.Fatalf("got %#v, expected %#v", re.ErrorCode, expected)
   504  	}
   505  }
   506  
   507  // TestTokenRetrieveError200 tests handling of unorthodox server that returns 200 in error case
   508  func TestTokenRetrieveError200(t *testing.T) {
   509  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   510  		if r.URL.String() != "/token" {
   511  			t.Errorf("Unexpected token refresh request URL, %v is found.", r.URL)
   512  		}
   513  		w.Header().Set("Content-type", "application/json")
   514  		w.Write([]byte(`{"error": "invalid_grant"}`))
   515  	}))
   516  	defer ts.Close()
   517  	conf := newConf(ts.URL)
   518  	_, err := conf.Exchange(context.Background(), "exchange-code")
   519  	if err == nil {
   520  		t.Fatalf("got no error, expected one")
   521  	}
   522  	re, ok := err.(*RetrieveError)
   523  	if !ok {
   524  		t.Fatalf("got %T error, expected *RetrieveError; error was: %v", err, err)
   525  	}
   526  	expected := `oauth2: "invalid_grant"`
   527  	if errStr := err.Error(); errStr != expected {
   528  		t.Fatalf("got %#v, expected %#v", errStr, expected)
   529  	}
   530  	expected = "invalid_grant"
   531  	if re.ErrorCode != expected {
   532  		t.Fatalf("got %#v, expected %#v", re.ErrorCode, expected)
   533  	}
   534  }
   535  
   536  func TestRefreshToken_RefreshTokenReplacement(t *testing.T) {
   537  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   538  		w.Header().Set("Content-Type", "application/json")
   539  		w.Write([]byte(`{"access_token":"ACCESS_TOKEN",  "scope": "user", "token_type": "bearer", "refresh_token": "NEW_REFRESH_TOKEN"}`))
   540  		return
   541  	}))
   542  	defer ts.Close()
   543  	conf := newConf(ts.URL)
   544  	tkr := conf.TokenSource(context.Background(), &Token{RefreshToken: "OLD_REFRESH_TOKEN"})
   545  	tk, err := tkr.Token()
   546  	if err != nil {
   547  		t.Errorf("got err = %v; want none", err)
   548  		return
   549  	}
   550  	if want := "NEW_REFRESH_TOKEN"; tk.RefreshToken != want {
   551  		t.Errorf("RefreshToken = %q; want %q", tk.RefreshToken, want)
   552  	}
   553  }
   554  
   555  func TestRefreshToken_RefreshTokenPreservation(t *testing.T) {
   556  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   557  		w.Header().Set("Content-Type", "application/json")
   558  		w.Write([]byte(`{"access_token":"ACCESS_TOKEN",  "scope": "user", "token_type": "bearer"}`))
   559  		return
   560  	}))
   561  	defer ts.Close()
   562  	conf := newConf(ts.URL)
   563  	const oldRefreshToken = "OLD_REFRESH_TOKEN"
   564  	tkr := conf.TokenSource(context.Background(), &Token{RefreshToken: oldRefreshToken})
   565  	tk, err := tkr.Token()
   566  	if err != nil {
   567  		t.Fatalf("got err = %v; want none", err)
   568  	}
   569  	if tk.RefreshToken != oldRefreshToken {
   570  		t.Errorf("RefreshToken = %q; want %q", tk.RefreshToken, oldRefreshToken)
   571  	}
   572  }
   573  
   574  func TestConfigClientWithToken(t *testing.T) {
   575  	tok := &Token{
   576  		AccessToken: "abc123",
   577  	}
   578  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   579  		if got, want := r.Header.Get("Authorization"), fmt.Sprintf("Bearer %s", tok.AccessToken); got != want {
   580  			t.Errorf("Authorization header = %q; want %q", got, want)
   581  		}
   582  		return
   583  	}))
   584  	defer ts.Close()
   585  	conf := newConf(ts.URL)
   586  
   587  	c := conf.Client(context.Background(), tok)
   588  	req, err := http.NewRequest("GET", ts.URL, nil)
   589  	if err != nil {
   590  		t.Error(err)
   591  	}
   592  	_, err = c.Do(req)
   593  	if err != nil {
   594  		t.Error(err)
   595  	}
   596  }
   597  

View as plain text