1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package idtoken
16
17 import (
18 "bytes"
19 "context"
20 "crypto"
21 "crypto/ecdsa"
22 "crypto/elliptic"
23 "crypto/rand"
24 "crypto/rsa"
25 "crypto/sha256"
26 "encoding/base64"
27 "encoding/json"
28 "fmt"
29 "io"
30 "math/big"
31 "net/http"
32 "testing"
33 "time"
34
35 "cloud.google.com/go/auth/internal/jwt"
36 )
37
// Fixtures shared by the tests in this file: keyID is the "kid" written into
// generated token headers and served by the stub cert endpoints, testAudience
// is the "aud" claim, and expiry is the fixed "exp" claim in Unix seconds.
const (
	keyID              = "1234"
	testAudience       = "test-audience"
	expiry       int64 = 233431200
)

// Clock stubs for the package-level now hook: beforeExp reports one second
// before expiry and afterExp one second after it.
var (
	beforeExp = func() time.Time { return time.Unix(expiry, 0).Add(-time.Second) }
	afterExp  = func() time.Time { return time.Unix(expiry, 0).Add(time.Second) }
)
48
49 func TestValidateRS256(t *testing.T) {
50 idToken, pk := createRS256JWT(t)
51 tests := []struct {
52 name string
53 keyID string
54 n *big.Int
55 e int
56 nowFunc func() time.Time
57 wantErr bool
58 }{
59 {
60 name: "works",
61 keyID: keyID,
62 n: pk.N,
63 e: pk.E,
64 nowFunc: beforeExp,
65 wantErr: false,
66 },
67 {
68 name: "no matching key",
69 keyID: "5678",
70 n: pk.N,
71 e: pk.E,
72 nowFunc: beforeExp,
73 wantErr: true,
74 },
75 {
76 name: "sig does not match",
77 keyID: keyID,
78 n: new(big.Int).SetBytes([]byte("42")),
79 e: 42,
80 nowFunc: beforeExp,
81 wantErr: true,
82 },
83 {
84 name: "token expired",
85 keyID: keyID,
86 n: pk.N,
87 e: pk.E,
88 nowFunc: afterExp,
89 wantErr: true,
90 },
91 }
92
93 for _, tt := range tests {
94 t.Run(tt.name, func(t *testing.T) {
95 client := &http.Client{
96 Transport: RoundTripFn(func(req *http.Request) *http.Response {
97 cr := certResponse{
98 Keys: []jwk{
99 {
100 Kid: tt.keyID,
101 N: base64.RawURLEncoding.EncodeToString(tt.n.Bytes()),
102 E: base64.RawURLEncoding.EncodeToString(new(big.Int).SetInt64(int64(tt.e)).Bytes()),
103 },
104 },
105 }
106 b, err := json.Marshal(&cr)
107 if err != nil {
108 t.Fatalf("unable to marshal response: %v", err)
109 }
110 return &http.Response{
111 StatusCode: 200,
112 Body: io.NopCloser(bytes.NewReader(b)),
113 Header: make(http.Header),
114 }
115 }),
116 }
117 oldNow := now
118 defer func() { now = oldNow }()
119 now = tt.nowFunc
120
121 v, err := NewValidator(&ValidatorOptions{
122 Client: client,
123 })
124 if err != nil {
125 t.Fatalf("NewValidator(...) = %q, want nil", err)
126 }
127 payload, err := v.Validate(context.Background(), idToken, testAudience)
128 if tt.wantErr && err != nil {
129
130 return
131 }
132 if !tt.wantErr && err != nil {
133 t.Fatalf("Validate(ctx, %s, %s): got err %q, want nil", idToken, testAudience, err)
134 }
135 if tt.wantErr && err == nil {
136 t.Fatalf("Validate(ctx, %s, %s): got nil err, want err", idToken, testAudience)
137 }
138 if payload == nil {
139 t.Fatalf("Got nil payload, err: %v", err)
140 }
141 if payload.Audience != testAudience {
142 t.Fatalf("Validate(ctx, %s, %s): got %v, want %v", idToken, testAudience, payload.Audience, testAudience)
143 }
144 if len(payload.Claims) == 0 {
145 t.Fatalf("Validate(ctx, %s, %s): missing Claims map. payload.Claims = %+v", idToken, testAudience, payload.Claims)
146 }
147 if got, ok := payload.Claims["aud"]; !ok {
148 t.Fatalf("Validate(ctx, %s, %s): missing aud claim. payload.Claims = %+v", idToken, testAudience, payload.Claims)
149 } else {
150 got, ok := got.(string)
151 if !ok {
152 t.Fatalf("Validate(ctx, %s, %s): aud wasn't a string. payload.Claims = %+v", idToken, testAudience, payload.Claims)
153 }
154 if got != testAudience {
155 t.Fatalf("Validate(ctx, %s, %s): Payload[aud] want %v got %v", idToken, testAudience, testAudience, got)
156 }
157 }
158 })
159 }
160 }
161
162 func TestValidateES256(t *testing.T) {
163 idToken, pk := createES256JWT(t)
164 tests := []struct {
165 name string
166 keyID string
167 x *big.Int
168 y *big.Int
169 nowFunc func() time.Time
170 wantErr bool
171 }{
172 {
173 name: "works",
174 keyID: keyID,
175 x: pk.X,
176 y: pk.Y,
177 nowFunc: beforeExp,
178 wantErr: false,
179 },
180 {
181 name: "no matching key",
182 keyID: "5678",
183 x: pk.X,
184 y: pk.Y,
185 nowFunc: beforeExp,
186 wantErr: true,
187 },
188 {
189 name: "sig does not match",
190 keyID: keyID,
191 x: new(big.Int),
192 y: new(big.Int),
193 nowFunc: beforeExp,
194 wantErr: true,
195 },
196 {
197 name: "token expired",
198 keyID: keyID,
199 x: pk.X,
200 y: pk.Y,
201 nowFunc: afterExp,
202 wantErr: true,
203 },
204 }
205 for _, tt := range tests {
206 t.Run(tt.name, func(t *testing.T) {
207 client := &http.Client{
208 Transport: RoundTripFn(func(req *http.Request) *http.Response {
209 cr := certResponse{
210 Keys: []jwk{
211 {
212 Kid: tt.keyID,
213 X: base64.RawURLEncoding.EncodeToString(tt.x.Bytes()),
214 Y: base64.RawURLEncoding.EncodeToString(tt.y.Bytes()),
215 },
216 },
217 }
218 b, err := json.Marshal(&cr)
219 if err != nil {
220 t.Fatalf("unable to marshal response: %v", err)
221 }
222 return &http.Response{
223 StatusCode: 200,
224 Body: io.NopCloser(bytes.NewReader(b)),
225 Header: make(http.Header),
226 }
227 }),
228 }
229 oldNow := now
230 defer func() { now = oldNow }()
231 now = tt.nowFunc
232
233 v, err := NewValidator(&ValidatorOptions{
234 Client: client,
235 })
236 if err != nil {
237 t.Fatalf("NewValidator(...) = %q, want nil", err)
238 }
239 payload, err := v.Validate(context.Background(), idToken, testAudience)
240 if !tt.wantErr && err != nil {
241 t.Fatalf("Validate(ctx, %s, %s) = %q, want nil", idToken, testAudience, err)
242 }
243 if !tt.wantErr && payload.Audience != testAudience {
244 t.Fatalf("got %v, want %v", payload.Audience, testAudience)
245 }
246 })
247 }
248 }
249
250 func TestParsePayload(t *testing.T) {
251 idToken, _ := createRS256JWT(t)
252 tests := []struct {
253 name string
254 token string
255 wantPayloadAudience string
256 wantErr bool
257 }{{
258 name: "valid token",
259 token: idToken,
260 wantPayloadAudience: testAudience,
261 }, {
262 name: "unparseable token",
263 token: "aaa.bbb.ccc",
264 wantErr: true,
265 }}
266
267 for _, tt := range tests {
268 t.Run(tt.name, func(t *testing.T) {
269 payload, err := ParsePayload(tt.token)
270 gotErr := err != nil
271 if gotErr != tt.wantErr {
272 t.Errorf("ParsePayload(%q) got error %v, wantErr = %v", tt.token, err, tt.wantErr)
273 }
274 if tt.wantPayloadAudience != "" {
275 if payload == nil || payload.Audience != tt.wantPayloadAudience {
276 t.Errorf("ParsePayload(%q) got payload %+v, want payload with audience = %q", tt.token, payload, tt.wantPayloadAudience)
277 }
278 }
279 })
280 }
281 }
282
283 func createES256JWT(t *testing.T) (string, ecdsa.PublicKey) {
284 t.Helper()
285 header, claims := commonToken(t, "ES256")
286 privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
287 if err != nil {
288 t.Fatalf("unable to generate key: %v", err)
289 }
290 signedContent := header + "." + claims
291 hashed := sha256.Sum256([]byte(signedContent))
292 hash := hashed[:]
293 r, s, err := ecdsa.Sign(rand.Reader, privateKey, hash)
294 if err != nil {
295 t.Fatalf("unable to sign content: %v", err)
296 }
297 rb := r.Bytes()
298 lPadded := make([]byte, es256KeySize)
299 copy(lPadded[es256KeySize-len(rb):], rb)
300 var sig []byte
301 sig = append(sig, lPadded...)
302 sig = append(sig, s.Bytes()...)
303 signature := base64.RawURLEncoding.EncodeToString(sig)
304 return fmt.Sprintf("%s.%s.%s", header, claims, signature), privateKey.PublicKey
305 }
306
307 func createRS256JWT(t *testing.T) (string, rsa.PublicKey) {
308 t.Helper()
309 header, claims := commonToken(t, jwt.HeaderAlgRSA256)
310 privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
311 if err != nil {
312 t.Fatalf("unable to generate key: %v", err)
313 }
314 signedContent := header + "." + claims
315 hashed := sha256.Sum256([]byte(signedContent))
316 hash := hashed[:]
317 sig, err := rsa.SignPKCS1v15(rand.Reader, privateKey, crypto.SHA256, hash)
318 if err != nil {
319 t.Fatalf("unable to sign content: %v", err)
320 }
321 signature := base64.RawURLEncoding.EncodeToString(sig)
322 return fmt.Sprintf("%s.%s.%s", header, claims, signature), privateKey.PublicKey
323 }
324
325
326 func commonToken(t *testing.T, alg string) (string, string) {
327 t.Helper()
328 header := jwt.Header{
329 KeyID: keyID,
330 Algorithm: alg,
331 Type: jwt.HeaderType,
332 }
333 payload := Payload{
334 Issuer: "example.com",
335 Audience: testAudience,
336 Expires: expiry,
337 }
338
339 hb, err := json.Marshal(&header)
340 if err != nil {
341 t.Fatalf("unable to marshall header: %v", err)
342 }
343 pb, err := json.Marshal(&payload)
344 if err != nil {
345 t.Fatalf("unable to marshall payload: %v", err)
346 }
347 eb := base64.RawURLEncoding.EncodeToString(hb)
348 ep := base64.RawURLEncoding.EncodeToString(pb)
349 return eb, ep
350 }
351
// RoundTripFn adapts a plain function into an http.RoundTripper so tests can
// stub HTTP responses without a network. The function's response is returned
// as-is and the error is always nil.
type RoundTripFn func(req *http.Request) *http.Response

// Compile-time check that RoundTripFn satisfies http.RoundTripper.
var _ http.RoundTripper = RoundTripFn(nil)

// RoundTrip implements http.RoundTripper by invoking fn on req.
func (fn RoundTripFn) RoundTrip(req *http.Request) (*http.Response, error) {
	resp := fn(req)
	return resp, nil
}
355
View as plain text