1 package oauth2
2
3 import (
4 "errors"
5 "io"
6 "net/http"
7 "net/http/httptest"
8 "testing"
9 "time"
10 )
11
12 type tokenSource struct{ token *Token }
13
14 func (t *tokenSource) Token() (*Token, error) {
15 return t.token, nil
16 }
17
18 func TestTransportNilTokenSource(t *testing.T) {
19 tr := &Transport{}
20 server := newMockServer(func(w http.ResponseWriter, r *http.Request) {})
21 defer server.Close()
22 client := &http.Client{Transport: tr}
23 resp, err := client.Get(server.URL)
24 if err == nil {
25 t.Errorf("got no errors, want an error with nil token source")
26 }
27 if resp != nil {
28 t.Errorf("Response = %v; want nil", resp)
29 }
30 }
31
// readCloseCounter is a request-body stub. Read never yields data and returns
// ReadErr; Close only records how many times it has been called.
type readCloseCounter struct {
	CloseCount int   // number of Close calls observed so far
	ReadErr    error // error returned from every Read call
}

// Read ignores the buffer and reports zero bytes together with ReadErr.
func (r *readCloseCounter) Read(_ []byte) (int, error) {
	err := r.ReadErr
	return 0, err
}

// Close bumps CloseCount and always succeeds.
func (r *readCloseCounter) Close() error {
	r.CloseCount++
	var noErr error
	return noErr
}
45
46 func TestTransportCloseRequestBody(t *testing.T) {
47 tr := &Transport{}
48 server := newMockServer(func(w http.ResponseWriter, r *http.Request) {})
49 defer server.Close()
50 client := &http.Client{Transport: tr}
51 body := &readCloseCounter{
52 ReadErr: errors.New("readCloseCounter.Read not implemented"),
53 }
54 resp, err := client.Post(server.URL, "application/json", body)
55 if err == nil {
56 t.Errorf("got no errors, want an error with nil token source")
57 }
58 if resp != nil {
59 t.Errorf("Response = %v; want nil", resp)
60 }
61 if expected := 1; body.CloseCount != expected {
62 t.Errorf("Body was closed %d times, expected %d", body.CloseCount, expected)
63 }
64 }
65
66 func TestTransportCloseRequestBodySuccess(t *testing.T) {
67 tr := &Transport{
68 Source: StaticTokenSource(&Token{
69 AccessToken: "abc",
70 }),
71 }
72 server := newMockServer(func(w http.ResponseWriter, r *http.Request) {})
73 defer server.Close()
74 client := &http.Client{Transport: tr}
75 body := &readCloseCounter{
76 ReadErr: io.EOF,
77 }
78 resp, err := client.Post(server.URL, "application/json", body)
79 if err != nil {
80 t.Errorf("got error %v; expected none", err)
81 }
82 if resp == nil {
83 t.Errorf("Response is nil; expected non-nil")
84 }
85 if expected := 1; body.CloseCount != expected {
86 t.Errorf("Body was closed %d times, expected %d", body.CloseCount, expected)
87 }
88 }
89
90 func TestTransportTokenSource(t *testing.T) {
91 ts := &tokenSource{
92 token: &Token{
93 AccessToken: "abc",
94 },
95 }
96 tr := &Transport{
97 Source: ts,
98 }
99 server := newMockServer(func(w http.ResponseWriter, r *http.Request) {
100 if got, want := r.Header.Get("Authorization"), "Bearer abc"; got != want {
101 t.Errorf("Authorization header = %q; want %q", got, want)
102 }
103 })
104 defer server.Close()
105 client := &http.Client{Transport: tr}
106 res, err := client.Get(server.URL)
107 if err != nil {
108 t.Fatal(err)
109 }
110 res.Body.Close()
111 }
112
113
114 func TestTransportTokenSourceTypes(t *testing.T) {
115 const val = "abc"
116 tests := []struct {
117 key string
118 val string
119 want string
120 }{
121 {key: "bearer", val: val, want: "Bearer abc"},
122 {key: "mac", val: val, want: "MAC abc"},
123 {key: "basic", val: val, want: "Basic abc"},
124 }
125 for _, tc := range tests {
126 ts := &tokenSource{
127 token: &Token{
128 AccessToken: tc.val,
129 TokenType: tc.key,
130 },
131 }
132 tr := &Transport{
133 Source: ts,
134 }
135 server := newMockServer(func(w http.ResponseWriter, r *http.Request) {
136 if got, want := r.Header.Get("Authorization"), tc.want; got != want {
137 t.Errorf("Authorization header (%q) = %q; want %q", val, got, want)
138 }
139 })
140 defer server.Close()
141 client := &http.Client{Transport: tr}
142 res, err := client.Get(server.URL)
143 if err != nil {
144 t.Fatal(err)
145 }
146 res.Body.Close()
147 }
148 }
149
150 func TestTokenValidNoAccessToken(t *testing.T) {
151 token := &Token{}
152 if token.Valid() {
153 t.Errorf("got valid with no access token; want invalid")
154 }
155 }
156
157 func TestExpiredWithExpiry(t *testing.T) {
158 token := &Token{
159 Expiry: time.Now().Add(-5 * time.Hour),
160 }
161 if token.Valid() {
162 t.Errorf("got valid with expired token; want invalid")
163 }
164 }
165
// newMockServer starts an httptest.Server that dispatches every request to
// handler. Callers are responsible for calling Close on the returned server.
func newMockServer(handler func(w http.ResponseWriter, r *http.Request)) *httptest.Server {
	var h http.Handler = http.HandlerFunc(handler)
	srv := httptest.NewServer(h)
	return srv
}
169
View as plain text