1 package token
2
3 import (
4 "crypto"
5 "crypto/rand"
6 "crypto/x509"
7 "encoding/base64"
8 "encoding/json"
9 "encoding/pem"
10 "fmt"
11 "io/ioutil"
12 "net/http"
13 "os"
14 "strings"
15 "testing"
16 "time"
17
18 "github.com/docker/distribution/context"
19 "github.com/docker/distribution/registry/auth"
20 "github.com/docker/libtrust"
21 )
22
23 func makeRootKeys(numKeys int) ([]libtrust.PrivateKey, error) {
24 keys := make([]libtrust.PrivateKey, 0, numKeys)
25
26 for i := 0; i < numKeys; i++ {
27 key, err := libtrust.GenerateECP256PrivateKey()
28 if err != nil {
29 return nil, err
30 }
31 keys = append(keys, key)
32 }
33
34 return keys, nil
35 }
36
37 func makeSigningKeyWithChain(rootKey libtrust.PrivateKey, depth int) (libtrust.PrivateKey, error) {
38 if depth == 0 {
39
40 return rootKey, nil
41 }
42
43 var (
44 x5c = make([]string, depth)
45 parentKey = rootKey
46 key libtrust.PrivateKey
47 cert *x509.Certificate
48 err error
49 )
50
51 for depth > 0 {
52 if key, err = libtrust.GenerateECP256PrivateKey(); err != nil {
53 return nil, err
54 }
55
56 if cert, err = libtrust.GenerateCACert(parentKey, key); err != nil {
57 return nil, err
58 }
59
60 depth--
61 x5c[depth] = base64.StdEncoding.EncodeToString(cert.Raw)
62 parentKey = key
63 }
64
65 key.AddExtendedField("x5c", x5c)
66
67 return key, nil
68 }
69
70 func makeRootCerts(rootKeys []libtrust.PrivateKey) ([]*x509.Certificate, error) {
71 certs := make([]*x509.Certificate, 0, len(rootKeys))
72
73 for _, key := range rootKeys {
74 cert, err := libtrust.GenerateCACert(key, key)
75 if err != nil {
76 return nil, err
77 }
78 certs = append(certs, cert)
79 }
80
81 return certs, nil
82 }
83
84 func makeTrustedKeyMap(rootKeys []libtrust.PrivateKey) map[string]libtrust.PublicKey {
85 trustedKeys := make(map[string]libtrust.PublicKey, len(rootKeys))
86
87 for _, key := range rootKeys {
88 trustedKeys[key.KeyID()] = key.PublicKey()
89 }
90
91 return trustedKeys
92 }
93
94 func makeTestToken(issuer, audience string, access []*ResourceActions, rootKey libtrust.PrivateKey, depth int, now time.Time, exp time.Time) (*Token, error) {
95 signingKey, err := makeSigningKeyWithChain(rootKey, depth)
96 if err != nil {
97 return nil, fmt.Errorf("unable to make signing key with chain: %s", err)
98 }
99
100 var rawJWK json.RawMessage
101 rawJWK, err = signingKey.PublicKey().MarshalJSON()
102 if err != nil {
103 return nil, fmt.Errorf("unable to marshal signing key to JSON: %s", err)
104 }
105
106 joseHeader := &Header{
107 Type: "JWT",
108 SigningAlg: "ES256",
109 RawJWK: &rawJWK,
110 }
111
112 randomBytes := make([]byte, 15)
113 if _, err = rand.Read(randomBytes); err != nil {
114 return nil, fmt.Errorf("unable to read random bytes for jwt id: %s", err)
115 }
116
117 claimSet := &ClaimSet{
118 Issuer: issuer,
119 Subject: "foo",
120 Audience: audience,
121 Expiration: exp.Unix(),
122 NotBefore: now.Unix(),
123 IssuedAt: now.Unix(),
124 JWTID: base64.URLEncoding.EncodeToString(randomBytes),
125 Access: access,
126 }
127
128 var joseHeaderBytes, claimSetBytes []byte
129
130 if joseHeaderBytes, err = json.Marshal(joseHeader); err != nil {
131 return nil, fmt.Errorf("unable to marshal jose header: %s", err)
132 }
133 if claimSetBytes, err = json.Marshal(claimSet); err != nil {
134 return nil, fmt.Errorf("unable to marshal claim set: %s", err)
135 }
136
137 encodedJoseHeader := joseBase64UrlEncode(joseHeaderBytes)
138 encodedClaimSet := joseBase64UrlEncode(claimSetBytes)
139 encodingToSign := fmt.Sprintf("%s.%s", encodedJoseHeader, encodedClaimSet)
140
141 var signatureBytes []byte
142 if signatureBytes, _, err = signingKey.Sign(strings.NewReader(encodingToSign), crypto.SHA256); err != nil {
143 return nil, fmt.Errorf("unable to sign jwt payload: %s", err)
144 }
145
146 signature := joseBase64UrlEncode(signatureBytes)
147 tokenString := fmt.Sprintf("%s.%s", encodingToSign, signature)
148
149 return NewToken(tokenString)
150 }
151
152
153
154
155 func TestTokenVerify(t *testing.T) {
156 var (
157 numTokens = 4
158 issuer = "test-issuer"
159 audience = "test-audience"
160 access = []*ResourceActions{
161 {
162 Type: "repository",
163 Name: "foo/bar",
164 Actions: []string{"pull", "push"},
165 },
166 }
167 )
168
169 rootKeys, err := makeRootKeys(numTokens)
170 if err != nil {
171 t.Fatal(err)
172 }
173
174 rootCerts, err := makeRootCerts(rootKeys)
175 if err != nil {
176 t.Fatal(err)
177 }
178
179 rootPool := x509.NewCertPool()
180 for _, rootCert := range rootCerts {
181 rootPool.AddCert(rootCert)
182 }
183
184 trustedKeys := makeTrustedKeyMap(rootKeys)
185
186 tokens := make([]*Token, 0, numTokens)
187
188 for i := 0; i < numTokens; i++ {
189 token, err := makeTestToken(issuer, audience, access, rootKeys[i], i, time.Now(), time.Now().Add(5*time.Minute))
190 if err != nil {
191 t.Fatal(err)
192 }
193 tokens = append(tokens, token)
194 }
195
196 verifyOps := VerifyOptions{
197 TrustedIssuers: []string{issuer},
198 AcceptedAudiences: []string{audience},
199 Roots: rootPool,
200 TrustedKeys: trustedKeys,
201 }
202
203 for _, token := range tokens {
204 if err := token.Verify(verifyOps); err != nil {
205 t.Fatal(err)
206 }
207 }
208 }
209
210
211
212 func TestLeeway(t *testing.T) {
213 var (
214 issuer = "test-issuer"
215 audience = "test-audience"
216 access = []*ResourceActions{
217 {
218 Type: "repository",
219 Name: "foo/bar",
220 Actions: []string{"pull", "push"},
221 },
222 }
223 )
224
225 rootKeys, err := makeRootKeys(1)
226 if err != nil {
227 t.Fatal(err)
228 }
229
230 trustedKeys := makeTrustedKeyMap(rootKeys)
231
232 verifyOps := VerifyOptions{
233 TrustedIssuers: []string{issuer},
234 AcceptedAudiences: []string{audience},
235 Roots: nil,
236 TrustedKeys: trustedKeys,
237 }
238
239
240 futureNow := time.Now().Add(time.Duration(5) * time.Second)
241 token, err := makeTestToken(issuer, audience, access, rootKeys[0], 0, futureNow, futureNow.Add(5*time.Minute))
242 if err != nil {
243 t.Fatal(err)
244 }
245
246 if err := token.Verify(verifyOps); err != nil {
247 t.Fatal(err)
248 }
249
250
251 futureNow = time.Now().Add(time.Duration(61) * time.Second)
252 token, err = makeTestToken(issuer, audience, access, rootKeys[0], 0, futureNow, futureNow.Add(5*time.Minute))
253 if err != nil {
254 t.Fatal(err)
255 }
256
257 if err = token.Verify(verifyOps); err == nil {
258 t.Fatal("Verification should fail for token with nbf in the future outside leeway")
259 }
260
261
262 token, err = makeTestToken(issuer, audience, access, rootKeys[0], 0, time.Now(), time.Now().Add(-59*time.Second))
263 if err != nil {
264 t.Fatal(err)
265 }
266
267 if err = token.Verify(verifyOps); err != nil {
268 t.Fatal(err)
269 }
270
271
272 token, err = makeTestToken(issuer, audience, access, rootKeys[0], 0, time.Now(), time.Now().Add(-60*time.Second))
273 if err != nil {
274 t.Fatal(err)
275 }
276
277 if err = token.Verify(verifyOps); err == nil {
278 t.Fatal("Verification should fail for token with exp in the future outside leeway")
279 }
280 }
281
282 func writeTempRootCerts(rootKeys []libtrust.PrivateKey) (filename string, err error) {
283 rootCerts, err := makeRootCerts(rootKeys)
284 if err != nil {
285 return "", err
286 }
287
288 tempFile, err := ioutil.TempFile("", "rootCertBundle")
289 if err != nil {
290 return "", err
291 }
292 defer tempFile.Close()
293
294 for _, cert := range rootCerts {
295 if err = pem.Encode(tempFile, &pem.Block{
296 Type: "CERTIFICATE",
297 Bytes: cert.Raw,
298 }); err != nil {
299 os.Remove(tempFile.Name())
300 return "", err
301 }
302 }
303
304 return tempFile.Name(), nil
305 }
306
307
308
309
310
311
312
313
314 func TestAccessController(t *testing.T) {
315
316 rootKeys, err := makeRootKeys(2)
317 if err != nil {
318 t.Fatal(err)
319 }
320
321 rootCertBundleFilename, err := writeTempRootCerts(rootKeys[:1])
322 if err != nil {
323 t.Fatal(err)
324 }
325 defer os.Remove(rootCertBundleFilename)
326
327 realm := "https://auth.example.com/token/"
328 issuer := "test-issuer.example.com"
329 service := "test-service.example.com"
330
331 options := map[string]interface{}{
332 "realm": realm,
333 "issuer": issuer,
334 "service": service,
335 "rootcertbundle": rootCertBundleFilename,
336 "autoredirect": false,
337 }
338
339 accessController, err := newAccessController(options)
340 if err != nil {
341 t.Fatal(err)
342 }
343
344
345 req, err := http.NewRequest("GET", "http://example.com/foo", nil)
346 if err != nil {
347 t.Fatal(err)
348 }
349
350 testAccess := auth.Access{
351 Resource: auth.Resource{
352 Type: "foo",
353 Name: "bar",
354 },
355 Action: "baz",
356 }
357
358 ctx := context.WithRequest(context.Background(), req)
359 authCtx, err := accessController.Authorized(ctx, testAccess)
360 challenge, ok := err.(auth.Challenge)
361 if !ok {
362 t.Fatal("accessController did not return a challenge")
363 }
364
365 if challenge.Error() != ErrTokenRequired.Error() {
366 t.Fatalf("accessControler did not get expected error - got %s - expected %s", challenge, ErrTokenRequired)
367 }
368
369 if authCtx != nil {
370 t.Fatalf("expected nil auth context but got %s", authCtx)
371 }
372
373
374 token, err := makeTestToken(
375 issuer, service,
376 []*ResourceActions{{
377 Type: testAccess.Type,
378 Name: testAccess.Name,
379 Actions: []string{testAccess.Action},
380 }},
381 rootKeys[1], 1, time.Now(), time.Now().Add(5*time.Minute),
382 )
383 if err != nil {
384 t.Fatal(err)
385 }
386
387 req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.compactRaw()))
388
389 authCtx, err = accessController.Authorized(ctx, testAccess)
390 challenge, ok = err.(auth.Challenge)
391 if !ok {
392 t.Fatal("accessController did not return a challenge")
393 }
394
395 if challenge.Error() != ErrInvalidToken.Error() {
396 t.Fatalf("accessControler did not get expected error - got %s - expected %s", challenge, ErrTokenRequired)
397 }
398
399 if authCtx != nil {
400 t.Fatalf("expected nil auth context but got %s", authCtx)
401 }
402
403
404 token, err = makeTestToken(
405 issuer, service,
406 []*ResourceActions{},
407 rootKeys[0], 1, time.Now(), time.Now().Add(5*time.Minute),
408 )
409 if err != nil {
410 t.Fatal(err)
411 }
412
413 req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.compactRaw()))
414
415 authCtx, err = accessController.Authorized(ctx, testAccess)
416 challenge, ok = err.(auth.Challenge)
417 if !ok {
418 t.Fatal("accessController did not return a challenge")
419 }
420
421 if challenge.Error() != ErrInsufficientScope.Error() {
422 t.Fatalf("accessControler did not get expected error - got %s - expected %s", challenge, ErrInsufficientScope)
423 }
424
425 if authCtx != nil {
426 t.Fatalf("expected nil auth context but got %s", authCtx)
427 }
428
429
430 token, err = makeTestToken(
431 issuer, service,
432 []*ResourceActions{{
433 Type: testAccess.Type,
434 Name: testAccess.Name,
435 Actions: []string{testAccess.Action},
436 }},
437 rootKeys[0], 1, time.Now(), time.Now().Add(5*time.Minute),
438 )
439 if err != nil {
440 t.Fatal(err)
441 }
442
443 req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.compactRaw()))
444
445 authCtx, err = accessController.Authorized(ctx, testAccess)
446 if err != nil {
447 t.Fatalf("accessController returned unexpected error: %s", err)
448 }
449
450 userInfo, ok := authCtx.Value(auth.UserKey).(auth.UserInfo)
451 if !ok {
452 t.Fatal("token accessController did not set auth.user context")
453 }
454
455 if userInfo.Name != "foo" {
456 t.Fatalf("expected user name %q, got %q", "foo", userInfo.Name)
457 }
458
459
460 token, err = makeTestToken(
461 issuer, service,
462 []*ResourceActions{{
463 Type: testAccess.Type,
464 Name: testAccess.Name,
465 Actions: []string{"*"},
466 }},
467 rootKeys[0], 1, time.Now(), time.Now().Add(5*time.Minute),
468 )
469 if err != nil {
470 t.Fatal(err)
471 }
472
473 req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.compactRaw()))
474
475 _, err = accessController.Authorized(ctx, testAccess)
476 if err != nil {
477 t.Fatalf("accessController returned unexpected error: %s", err)
478 }
479 }
480
481
482
483 func TestNewAccessControllerPemBlock(t *testing.T) {
484 rootKeys, err := makeRootKeys(2)
485 if err != nil {
486 t.Fatal(err)
487 }
488
489 rootCertBundleFilename, err := writeTempRootCerts(rootKeys)
490 if err != nil {
491 t.Fatal(err)
492 }
493 defer os.Remove(rootCertBundleFilename)
494
495
496 file, err := os.OpenFile(rootCertBundleFilename, os.O_WRONLY|os.O_APPEND, 0666)
497 if err != nil {
498 t.Fatal(err)
499 }
500 keyBlock, err := rootKeys[0].PEMBlock()
501 if err != nil {
502 t.Fatal(err)
503 }
504 err = pem.Encode(file, keyBlock)
505 if err != nil {
506 t.Fatal(err)
507 }
508 err = file.Close()
509 if err != nil {
510 t.Fatal(err)
511 }
512
513 realm := "https://auth.example.com/token/"
514 issuer := "test-issuer.example.com"
515 service := "test-service.example.com"
516
517 options := map[string]interface{}{
518 "realm": realm,
519 "issuer": issuer,
520 "service": service,
521 "rootcertbundle": rootCertBundleFilename,
522 "autoredirect": false,
523 }
524
525 ac, err := newAccessController(options)
526 if err != nil {
527 t.Fatal(err)
528 }
529
530 if len(ac.(*accessController).rootCerts.Subjects()) != 2 {
531 t.Fatal("accessController has the wrong number of certificates")
532 }
533 }
534
View as plain text