1
2
3
4
5 package idtoken
6
7 import (
8 "bytes"
9 "context"
10 "crypto"
11 "crypto/ecdsa"
12 "crypto/elliptic"
13 "crypto/rand"
14 "crypto/rsa"
15 "encoding/base64"
16 "encoding/json"
17 "io"
18 "math/big"
19 "net/http"
20 "testing"
21 "time"
22
23 "google.golang.org/api/option"
24 )
25
// Fixture values shared by the token builders and the validation tests.
const (
	// keyID is the "kid" written into every test token header and advertised
	// by the stubbed cert endpoint in the passing cases.
	keyID = "1234"
	// testAudience is the audience baked into every test token and the one
	// the tests validate against.
	testAudience = "test-audience"
	// expiry is the token expiration time in Unix seconds.
	expiry int64 = 233431200
)
31
32 var (
33 beforeExp = func() time.Time { return time.Unix(expiry-1, 0) }
34 afterExp = func() time.Time { return time.Unix(expiry+1, 0) }
35 )
36
37 func TestValidateRS256(t *testing.T) {
38 idToken, pk := createRS256JWT(t)
39 tests := []struct {
40 name string
41 keyID string
42 n *big.Int
43 e int
44 nowFunc func() time.Time
45 wantErr bool
46 }{
47 {
48 name: "works",
49 keyID: keyID,
50 n: pk.N,
51 e: pk.E,
52 nowFunc: beforeExp,
53 wantErr: false,
54 },
55 {
56 name: "no matching key",
57 keyID: "5678",
58 n: pk.N,
59 e: pk.E,
60 nowFunc: beforeExp,
61 wantErr: true,
62 },
63 {
64 name: "sig does not match",
65 keyID: keyID,
66 n: new(big.Int).SetBytes([]byte("42")),
67 e: 42,
68 nowFunc: beforeExp,
69 wantErr: true,
70 },
71 {
72 name: "token expired",
73 keyID: keyID,
74 n: pk.N,
75 e: pk.E,
76 nowFunc: afterExp,
77 wantErr: true,
78 },
79 }
80
81 for _, tt := range tests {
82 t.Run(tt.name, func(t *testing.T) {
83 client := &http.Client{
84 Transport: RoundTripFn(func(req *http.Request) *http.Response {
85 cr := certResponse{
86 Keys: []jwk{
87 {
88 Kid: tt.keyID,
89 N: base64.RawURLEncoding.EncodeToString(tt.n.Bytes()),
90 E: base64.RawURLEncoding.EncodeToString(new(big.Int).SetInt64(int64(tt.e)).Bytes()),
91 },
92 },
93 }
94 b, err := json.Marshal(&cr)
95 if err != nil {
96 t.Fatalf("unable to marshal response: %v", err)
97 }
98 return &http.Response{
99 StatusCode: 200,
100 Body: io.NopCloser(bytes.NewReader(b)),
101 Header: make(http.Header),
102 }
103 }),
104 }
105 oldNow := now
106 defer func() { now = oldNow }()
107 now = tt.nowFunc
108
109 v, err := NewValidator(context.Background(), option.WithHTTPClient(client))
110 if err != nil {
111 t.Fatalf("NewValidator(...) = %q, want nil", err)
112 }
113 payload, err := v.Validate(context.Background(), idToken, testAudience)
114 if tt.wantErr && err != nil {
115
116 return
117 }
118 if !tt.wantErr && err != nil {
119 t.Fatalf("Validate(ctx, %s, %s): got err %q, want nil", idToken, testAudience, err)
120 }
121 if tt.wantErr && err == nil {
122 t.Fatalf("Validate(ctx, %s, %s): got nil err, want err", idToken, testAudience)
123 }
124 if payload == nil {
125 t.Fatalf("Got nil payload, err: %v", err)
126 }
127 if payload.Audience != testAudience {
128 t.Fatalf("Validate(ctx, %s, %s): got %v, want %v", idToken, testAudience, payload.Audience, testAudience)
129 }
130 if len(payload.Claims) == 0 {
131 t.Fatalf("Validate(ctx, %s, %s): missing Claims map. payload.Claims = %+v", idToken, testAudience, payload.Claims)
132 }
133 if got, ok := payload.Claims["aud"]; !ok {
134 t.Fatalf("Validate(ctx, %s, %s): missing aud claim. payload.Claims = %+v", idToken, testAudience, payload.Claims)
135 } else {
136 got, ok := got.(string)
137 if !ok {
138 t.Fatalf("Validate(ctx, %s, %s): aud wasn't a string. payload.Claims = %+v", idToken, testAudience, payload.Claims)
139 }
140 if got != testAudience {
141 t.Fatalf("Validate(ctx, %s, %s): Payload[aud] want %v got %v", idToken, testAudience, testAudience, got)
142 }
143 }
144 })
145 }
146 }
147
148 func TestValidateES256(t *testing.T) {
149 idToken, pk := createES256JWT(t)
150 tests := []struct {
151 name string
152 keyID string
153 x *big.Int
154 y *big.Int
155 nowFunc func() time.Time
156 wantErr bool
157 }{
158 {
159 name: "works",
160 keyID: keyID,
161 x: pk.X,
162 y: pk.Y,
163 nowFunc: beforeExp,
164 wantErr: false,
165 },
166 {
167 name: "no matching key",
168 keyID: "5678",
169 x: pk.X,
170 y: pk.Y,
171 nowFunc: beforeExp,
172 wantErr: true,
173 },
174 {
175 name: "sig does not match",
176 keyID: keyID,
177 x: new(big.Int),
178 y: new(big.Int),
179 nowFunc: beforeExp,
180 wantErr: true,
181 },
182 {
183 name: "token expired",
184 keyID: keyID,
185 x: pk.X,
186 y: pk.Y,
187 nowFunc: afterExp,
188 wantErr: true,
189 },
190 }
191 for _, tt := range tests {
192 t.Run(tt.name, func(t *testing.T) {
193 client := &http.Client{
194 Transport: RoundTripFn(func(req *http.Request) *http.Response {
195 cr := certResponse{
196 Keys: []jwk{
197 {
198 Kid: tt.keyID,
199 X: base64.RawURLEncoding.EncodeToString(tt.x.Bytes()),
200 Y: base64.RawURLEncoding.EncodeToString(tt.y.Bytes()),
201 },
202 },
203 }
204 b, err := json.Marshal(&cr)
205 if err != nil {
206 t.Fatalf("unable to marshal response: %v", err)
207 }
208 return &http.Response{
209 StatusCode: 200,
210 Body: io.NopCloser(bytes.NewReader(b)),
211 Header: make(http.Header),
212 }
213 }),
214 }
215 oldNow := now
216 defer func() { now = oldNow }()
217 now = tt.nowFunc
218
219 v, err := NewValidator(context.Background(), option.WithHTTPClient(client))
220 if err != nil {
221 t.Fatalf("NewValidator(...) = %q, want nil", err)
222 }
223 payload, err := v.Validate(context.Background(), idToken, testAudience)
224 if !tt.wantErr && err != nil {
225 t.Fatalf("Validate(ctx, %s, %s) = %q, want nil", idToken, testAudience, err)
226 }
227 if !tt.wantErr && payload.Audience != testAudience {
228 t.Fatalf("got %v, want %v", payload.Audience, testAudience)
229 }
230 })
231 }
232 }
233
234 func TestParsePayload(t *testing.T) {
235 idToken, _ := createRS256JWT(t)
236 tests := []struct {
237 name string
238 token string
239 wantPayloadAudience string
240 wantErr bool
241 }{{
242 name: "valid token",
243 token: idToken,
244 wantPayloadAudience: testAudience,
245 }, {
246 name: "unparseable token",
247 token: "aaa.bbb.ccc",
248 wantErr: true,
249 }}
250
251 for _, tt := range tests {
252 t.Run(tt.name, func(t *testing.T) {
253 payload, err := ParsePayload(tt.token)
254 gotErr := err != nil
255 if gotErr != tt.wantErr {
256 t.Errorf("ParsePayload(%q) got error %v, wantErr = %v", tt.token, err, tt.wantErr)
257 }
258 if tt.wantPayloadAudience != "" {
259 if payload == nil || payload.Audience != tt.wantPayloadAudience {
260 t.Errorf("ParsePayload(%q) got payload %+v, want payload with audience = %q", tt.token, payload, tt.wantPayloadAudience)
261 }
262 }
263 })
264 }
265 }
266
267 func createES256JWT(t *testing.T) (string, ecdsa.PublicKey) {
268 t.Helper()
269 token := commonToken(t, "ES256")
270 privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
271 if err != nil {
272 t.Fatalf("unable to generate key: %v", err)
273 }
274 r, s, err := ecdsa.Sign(rand.Reader, privateKey, token.hashedContent())
275 if err != nil {
276 t.Fatalf("unable to sign content: %v", err)
277 }
278 rb := r.Bytes()
279 lPadded := make([]byte, es256KeySize)
280 copy(lPadded[es256KeySize-len(rb):], rb)
281 var sig []byte
282 sig = append(sig, lPadded...)
283 sig = append(sig, s.Bytes()...)
284 token.signature = base64.RawURLEncoding.EncodeToString(sig)
285 return token.String(), privateKey.PublicKey
286 }
287
288 func createRS256JWT(t *testing.T) (string, rsa.PublicKey) {
289 t.Helper()
290 token := commonToken(t, "RS256")
291 privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
292 if err != nil {
293 t.Fatalf("unable to generate key: %v", err)
294 }
295 sig, err := rsa.SignPKCS1v15(rand.Reader, privateKey, crypto.SHA256, token.hashedContent())
296 if err != nil {
297 t.Fatalf("unable to sign content: %v", err)
298 }
299 token.signature = base64.RawURLEncoding.EncodeToString(sig)
300 return token.String(), privateKey.PublicKey
301 }
302
303 func commonToken(t *testing.T, alg string) *jwt {
304 t.Helper()
305 header := jwtHeader{
306 KeyID: keyID,
307 Algorithm: alg,
308 Type: "JWT",
309 }
310 payload := Payload{
311 Issuer: "example.com",
312 Audience: testAudience,
313 Expires: expiry,
314 }
315
316 hb, err := json.Marshal(&header)
317 if err != nil {
318 t.Fatalf("unable to marshall header: %v", err)
319 }
320 pb, err := json.Marshal(&payload)
321 if err != nil {
322 t.Fatalf("unable to marshall payload: %v", err)
323 }
324 eb := base64.RawURLEncoding.EncodeToString(hb)
325 ep := base64.RawURLEncoding.EncodeToString(pb)
326 return &jwt{
327 header: eb,
328 payload: ep,
329 }
330 }
331
// RoundTripFn adapts a plain function to the http.RoundTripper interface, so
// tests can stub HTTP responses without starting a server.
type RoundTripFn func(req *http.Request) *http.Response

// RoundTrip implements http.RoundTripper by calling f with the request. It
// always returns a nil error alongside f's response.
func (f RoundTripFn) RoundTrip(r *http.Request) (*http.Response, error) {
	resp := f(r)
	return resp, nil
}
335
View as plain text