1
2
3
4
5 package jwt
6
7 import (
8 "context"
9 "encoding/base64"
10 "encoding/json"
11 "fmt"
12 "net/http"
13 "net/http/httptest"
14 "reflect"
15 "strings"
16 "testing"
17
18 "golang.org/x/oauth2"
19 "golang.org/x/oauth2/jws"
20 )
21
// dummyPrivateKey is a PEM-encoded ("BEGIN RSA PRIVATE KEY") RSA private key
// used as Config.PrivateKey by every test in this file to sign the JWT
// assertions sent to the local httptest token servers.
// NOTE(review): this key is committed to source and therefore public; it is
// a throwaway fixture and must never be used for real credentials.
var dummyPrivateKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAx4fm7dngEmOULNmAs1IGZ9Apfzh+BkaQ1dzkmbUgpcoghucE
DZRnAGd2aPyB6skGMXUytWQvNYav0WTR00wFtX1ohWTfv68HGXJ8QXCpyoSKSSFY
fuP9X36wBSkSX9J5DVgiuzD5VBdzUISSmapjKm+DcbRALjz6OUIPEWi1Tjl6p5RK
1w41qdbmt7E5/kGhKLDuT7+M83g4VWhgIvaAXtnhklDAggilPPa8ZJ1IFe31lNlr
k4DRk38nc6sEutdf3RL7QoH7FBusI7uXV03DC6dwN1kP4GE7bjJhcRb/7jYt7CQ9
/E9Exz3c0yAp0yrTg0Fwh+qxfH9dKwN52S7SBwIDAQABAoIBAQCaCs26K07WY5Jt
3a2Cw3y2gPrIgTCqX6hJs7O5ByEhXZ8nBwsWANBUe4vrGaajQHdLj5OKfsIDrOvn
2NI1MqflqeAbu/kR32q3tq8/Rl+PPiwUsW3E6Pcf1orGMSNCXxeducF2iySySzh3
nSIhCG5uwJDWI7a4+9KiieFgK1pt/Iv30q1SQS8IEntTfXYwANQrfKUVMmVF9aIK
6/WZE2yd5+q3wVVIJ6jsmTzoDCX6QQkkJICIYwCkglmVy5AeTckOVwcXL0jqw5Kf
5/soZJQwLEyBoQq7Kbpa26QHq+CJONetPP8Ssy8MJJXBT+u/bSseMb3Zsr5cr43e
DJOhwsThAoGBAPY6rPKl2NT/K7XfRCGm1sbWjUQyDShscwuWJ5+kD0yudnT/ZEJ1
M3+KS/iOOAoHDdEDi9crRvMl0UfNa8MAcDKHflzxg2jg/QI+fTBjPP5GOX0lkZ9g
z6VePoVoQw2gpPFVNPPTxKfk27tEzbaffvOLGBEih0Kb7HTINkW8rIlzAoGBAM9y
1yr+jvfS1cGFtNU+Gotoihw2eMKtIqR03Yn3n0PK1nVCDKqwdUqCypz4+ml6cxRK
J8+Pfdh7D+ZJd4LEG6Y4QRDLuv5OA700tUoSHxMSNn3q9As4+T3MUyYxWKvTeu3U
f2NWP9ePU0lV8ttk7YlpVRaPQmc1qwooBA/z/8AdAoGAW9x0HWqmRICWTBnpjyxx
QGlW9rQ9mHEtUotIaRSJ6K/F3cxSGUEkX1a3FRnp6kPLcckC6NlqdNgNBd6rb2rA
cPl/uSkZP42Als+9YMoFPU/xrrDPbUhu72EDrj3Bllnyb168jKLa4VBOccUvggxr
Dm08I1hgYgdN5huzs7y6GeUCgYEAj+AZJSOJ6o1aXS6rfV3mMRve9bQ9yt8jcKXw
5HhOCEmMtaSKfnOF1Ziih34Sxsb7O2428DiX0mV/YHtBnPsAJidL0SdLWIapBzeg
KHArByIRkwE6IvJvwpGMdaex1PIGhx5i/3VZL9qiq/ElT05PhIb+UXgoWMabCp84
OgxDK20CgYAeaFo8BdQ7FmVX2+EEejF+8xSge6WVLtkaon8bqcn6P0O8lLypoOhd
mJAYH8WU+UAy9pecUnDZj14LAGNVmYcse8HFX71MoshnvCTFEPVo4rZxIAGwMpeJ
5jgQ3slYLpqrGlcbLgUXBUgzEO684Wk/UV9DFPlHALVqCfXQ9dpJPg==
-----END RSA PRIVATE KEY-----`)
49
50 func TestJWTFetch_JSONResponse(t *testing.T) {
51 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
52 w.Header().Set("Content-Type", "application/json")
53 w.Write([]byte(`{
54 "access_token": "90d64460d14870c08c81352a05dedd3465940a7c",
55 "scope": "user",
56 "token_type": "bearer",
57 "expires_in": 3600
58 }`))
59 }))
60 defer ts.Close()
61
62 conf := &Config{
63 Email: "aaa@xxx.com",
64 PrivateKey: dummyPrivateKey,
65 TokenURL: ts.URL,
66 }
67 tok, err := conf.TokenSource(context.Background()).Token()
68 if err != nil {
69 t.Fatal(err)
70 }
71 if !tok.Valid() {
72 t.Errorf("got invalid token: %v", tok)
73 }
74 if got, want := tok.AccessToken, "90d64460d14870c08c81352a05dedd3465940a7c"; got != want {
75 t.Errorf("access token = %q; want %q", got, want)
76 }
77 if got, want := tok.TokenType, "bearer"; got != want {
78 t.Errorf("token type = %q; want %q", got, want)
79 }
80 if got := tok.Expiry.IsZero(); got {
81 t.Errorf("token expiry = %v, want none", got)
82 }
83 scope := tok.Extra("scope")
84 if got, want := scope, "user"; got != want {
85 t.Errorf("scope = %q; want %q", got, want)
86 }
87 }
88
89 func TestJWTFetch_BadResponse(t *testing.T) {
90 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
91 w.Header().Set("Content-Type", "application/json")
92 w.Write([]byte(`{"scope": "user", "token_type": "bearer"}`))
93 }))
94 defer ts.Close()
95
96 conf := &Config{
97 Email: "aaa@xxx.com",
98 PrivateKey: dummyPrivateKey,
99 TokenURL: ts.URL,
100 }
101 tok, err := conf.TokenSource(context.Background()).Token()
102 if err != nil {
103 t.Fatal(err)
104 }
105 if tok == nil {
106 t.Fatalf("got nil token; want token")
107 }
108 if tok.Valid() {
109 t.Errorf("got invalid token: %v", tok)
110 }
111 if got, want := tok.AccessToken, ""; got != want {
112 t.Errorf("access token = %q; want %q", got, want)
113 }
114 if got, want := tok.TokenType, "bearer"; got != want {
115 t.Errorf("token type = %q; want %q", got, want)
116 }
117 scope := tok.Extra("scope")
118 if got, want := scope, "user"; got != want {
119 t.Errorf("token scope = %q; want %q", got, want)
120 }
121 }
122
123 func TestJWTFetch_BadResponseType(t *testing.T) {
124 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
125 w.Header().Set("Content-Type", "application/json")
126 w.Write([]byte(`{"access_token":123, "scope": "user", "token_type": "bearer"}`))
127 }))
128 defer ts.Close()
129 conf := &Config{
130 Email: "aaa@xxx.com",
131 PrivateKey: dummyPrivateKey,
132 TokenURL: ts.URL,
133 }
134 tok, err := conf.TokenSource(context.Background()).Token()
135 if err == nil {
136 t.Error("got a token; expected error")
137 if got, want := tok.AccessToken, ""; got != want {
138 t.Errorf("access token = %q; want %q", got, want)
139 }
140 }
141 }
142
143 func TestJWTFetch_Assertion(t *testing.T) {
144 var assertion string
145 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
146 r.ParseForm()
147 assertion = r.Form.Get("assertion")
148
149 w.Header().Set("Content-Type", "application/json")
150 w.Write([]byte(`{
151 "access_token": "90d64460d14870c08c81352a05dedd3465940a7c",
152 "scope": "user",
153 "token_type": "bearer",
154 "expires_in": 3600
155 }`))
156 }))
157 defer ts.Close()
158
159 conf := &Config{
160 Email: "aaa@xxx.com",
161 PrivateKey: dummyPrivateKey,
162 PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
163 TokenURL: ts.URL,
164 }
165
166 _, err := conf.TokenSource(context.Background()).Token()
167 if err != nil {
168 t.Fatalf("Failed to fetch token: %v", err)
169 }
170
171 parts := strings.Split(assertion, ".")
172 if len(parts) != 3 {
173 t.Fatalf("assertion = %q; want 3 parts", assertion)
174 }
175 gotjson, err := base64.RawURLEncoding.DecodeString(parts[0])
176 if err != nil {
177 t.Fatalf("invalid token header; err = %v", err)
178 }
179
180 got := jws.Header{}
181 if err := json.Unmarshal(gotjson, &got); err != nil {
182 t.Errorf("failed to unmarshal json token header = %q; err = %v", gotjson, err)
183 }
184
185 want := jws.Header{
186 Algorithm: "RS256",
187 Typ: "JWT",
188 KeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
189 }
190 if got != want {
191 t.Errorf("access token header = %q; want %q", got, want)
192 }
193 }
194
195 func TestJWTFetch_AssertionPayload(t *testing.T) {
196 var assertion string
197 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
198 r.ParseForm()
199 assertion = r.Form.Get("assertion")
200
201 w.Header().Set("Content-Type", "application/json")
202 w.Write([]byte(`{
203 "access_token": "90d64460d14870c08c81352a05dedd3465940a7c",
204 "scope": "user",
205 "token_type": "bearer",
206 "expires_in": 3600
207 }`))
208 }))
209 defer ts.Close()
210
211 for _, conf := range []*Config{
212 {
213 Email: "aaa1@xxx.com",
214 PrivateKey: dummyPrivateKey,
215 PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
216 TokenURL: ts.URL,
217 },
218 {
219 Email: "aaa2@xxx.com",
220 PrivateKey: dummyPrivateKey,
221 PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
222 TokenURL: ts.URL,
223 Audience: "https://example.com",
224 },
225 {
226 Email: "aaa2@xxx.com",
227 PrivateKey: dummyPrivateKey,
228 PrivateKeyID: "ABCDEFGHIJKLMNOPQRSTUVWXYZ",
229 TokenURL: ts.URL,
230 PrivateClaims: map[string]interface{}{
231 "private0": "claim0",
232 "private1": "claim1",
233 },
234 },
235 } {
236 t.Run(conf.Email, func(t *testing.T) {
237 _, err := conf.TokenSource(context.Background()).Token()
238 if err != nil {
239 t.Fatalf("Failed to fetch token: %v", err)
240 }
241
242 parts := strings.Split(assertion, ".")
243 if len(parts) != 3 {
244 t.Fatalf("assertion = %q; want 3 parts", assertion)
245 }
246 gotjson, err := base64.RawURLEncoding.DecodeString(parts[1])
247 if err != nil {
248 t.Fatalf("invalid token payload; err = %v", err)
249 }
250
251 claimSet := jws.ClaimSet{}
252 if err := json.Unmarshal(gotjson, &claimSet); err != nil {
253 t.Errorf("failed to unmarshal json token payload = %q; err = %v", gotjson, err)
254 }
255
256 if got, want := claimSet.Iss, conf.Email; got != want {
257 t.Errorf("payload email = %q; want %q", got, want)
258 }
259 if got, want := claimSet.Scope, strings.Join(conf.Scopes, " "); got != want {
260 t.Errorf("payload scope = %q; want %q", got, want)
261 }
262 aud := conf.TokenURL
263 if conf.Audience != "" {
264 aud = conf.Audience
265 }
266 if got, want := claimSet.Aud, aud; got != want {
267 t.Errorf("payload audience = %q; want %q", got, want)
268 }
269 if got, want := claimSet.Sub, conf.Subject; got != want {
270 t.Errorf("payload subject = %q; want %q", got, want)
271 }
272 if got, want := claimSet.Prn, conf.Subject; got != want {
273 t.Errorf("payload prn = %q; want %q", got, want)
274 }
275 if len(conf.PrivateClaims) > 0 {
276 var got interface{}
277 if err := json.Unmarshal(gotjson, &got); err != nil {
278 t.Errorf("failed to parse payload; err = %q", err)
279 }
280 m := got.(map[string]interface{})
281 for v, k := range conf.PrivateClaims {
282 if !reflect.DeepEqual(m[v], k) {
283 t.Errorf("payload private claims key = %q: got %#v; want %#v", v, m[v], k)
284 }
285 }
286 }
287 })
288 }
289 }
290
291 func TestTokenRetrieveError(t *testing.T) {
292 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
293 w.Header().Set("Content-type", "application/json")
294 w.WriteHeader(http.StatusBadRequest)
295 w.Write([]byte(`{"error": "invalid_grant"}`))
296 }))
297 defer ts.Close()
298
299 conf := &Config{
300 Email: "aaa@xxx.com",
301 PrivateKey: dummyPrivateKey,
302 TokenURL: ts.URL,
303 }
304
305 _, err := conf.TokenSource(context.Background()).Token()
306 if err == nil {
307 t.Fatalf("got no error, expected one")
308 }
309 _, ok := err.(*oauth2.RetrieveError)
310 if !ok {
311 t.Fatalf("got %T error, expected *RetrieveError", err)
312 }
313
314 expected := fmt.Sprintf("oauth2: cannot fetch token: %v\nResponse: %s", "400 Bad Request", `{"error": "invalid_grant"}`)
315 if errStr := err.Error(); errStr != expected {
316 t.Fatalf("got %#v, expected %#v", errStr, expected)
317 }
318 }
319
View as plain text