1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package transport
16
17 import (
18 "context"
19 "fmt"
20 "net/http"
21 "net/http/httptest"
22 "net/url"
23 "strings"
24 "testing"
25
26 "github.com/google/go-containerregistry/pkg/authn"
27 "github.com/google/go-containerregistry/pkg/name"
28 )
29
30 func TestBearerRefresh(t *testing.T) {
31 expectedToken := "Sup3rDup3rS3cr3tz"
32 expectedScope := "this-is-your-scope"
33 expectedService := "my-service.io"
34
35 cases := []struct {
36 tokenKey string
37 wantErr bool
38 }{{
39 tokenKey: "token",
40 wantErr: false,
41 }, {
42 tokenKey: "access_token",
43 wantErr: false,
44 }, {
45 tokenKey: "tolkien",
46 wantErr: true,
47 }}
48
49 for _, tc := range cases {
50 t.Run(tc.tokenKey, func(t *testing.T) {
51 server := httptest.NewServer(
52 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
53 hdr := r.Header.Get("Authorization")
54 if !strings.HasPrefix(hdr, "Basic ") {
55 t.Errorf("Header.Get(Authorization); got %v, want Basic prefix", hdr)
56 }
57 if got, want := r.FormValue("scope"), expectedScope; got != want {
58 t.Errorf("FormValue(scope); got %v, want %v", got, want)
59 }
60 if got, want := r.FormValue("service"), expectedService; got != want {
61 t.Errorf("FormValue(service); got %v, want %v", got, want)
62 }
63 w.Write([]byte(fmt.Sprintf(`{%q: %q}`, tc.tokenKey, expectedToken)))
64 }))
65 defer server.Close()
66
67 basic := &authn.Basic{Username: "foo", Password: "bar"}
68 registry, err := name.NewRegistry(expectedService, name.WeakValidation)
69 if err != nil {
70 t.Errorf("Unexpected error during NewRegistry: %v", err)
71 }
72
73 bt := &bearerTransport{
74 inner: http.DefaultTransport,
75 basic: basic,
76 registry: registry,
77 realm: server.URL,
78 scopes: []string{expectedScope},
79 service: expectedService,
80 scheme: "http",
81 }
82
83 if err := bt.refresh(context.Background()); (err != nil) != tc.wantErr {
84 t.Errorf("refresh() = %v", err)
85 }
86 })
87 }
88 }
89
90 func TestBearerTransport(t *testing.T) {
91 expectedToken := "sdkjhfskjdhfkjshdf"
92
93 blobServer := httptest.NewServer(
94 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
95
96 if got := r.Header.Get("Authorization"); got != "" {
97 t.Errorf("Header.Get(Authorization); got %v, want empty string", got)
98 }
99 w.WriteHeader(http.StatusOK)
100 }))
101 defer blobServer.Close()
102
103 server := httptest.NewServer(
104 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
105 if got, want := r.Header.Get("Authorization"), "Bearer "+expectedToken; got != want {
106 t.Errorf("Header.Get(Authorization); got %v, want %v", got, want)
107 }
108 if r.URL.Path == "/v2/auth" {
109 http.Redirect(w, r, "/redirect", http.StatusMovedPermanently)
110 return
111 }
112 if strings.Contains(r.URL.Path, "blobs") {
113 http.Redirect(w, r, blobServer.URL, http.StatusFound)
114 return
115 }
116 w.WriteHeader(http.StatusOK)
117 }))
118 defer server.Close()
119
120 u, err := url.Parse(server.URL)
121 if err != nil {
122 t.Errorf("Unexpected error during url.Parse: %v", err)
123 }
124 registry, err := name.NewRegistry(u.Host, name.WeakValidation)
125 if err != nil {
126 t.Errorf("Unexpected error during NewRegistry: %v", err)
127 }
128
129 client := http.Client{Transport: &bearerTransport{
130 inner: &http.Transport{},
131 bearer: authn.AuthConfig{RegistryToken: expectedToken},
132 registry: registry,
133 scheme: "http",
134 }}
135
136 _, err = client.Get(fmt.Sprintf("http://%s/v2/auth", u.Host))
137 if err != nil {
138 t.Errorf("Unexpected error during Get: %v", err)
139 }
140
141 _, err = client.Get(fmt.Sprintf("http://%s/v2/foo/bar/blobs/blah", u.Host))
142 if err != nil {
143 t.Errorf("Unexpected error during Get: %v", err)
144 }
145 }
146
147 func TestBearerTransportTokenRefresh(t *testing.T) {
148 initialToken := "foo"
149 refreshedToken := "bar"
150
151 server := httptest.NewServer(
152 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
153 hdr := r.Header.Get("Authorization")
154 if hdr == "Bearer "+refreshedToken {
155 w.WriteHeader(http.StatusOK)
156 return
157 }
158 if strings.HasPrefix(hdr, "Basic ") {
159 w.Write([]byte(fmt.Sprintf(`{"token": %q}`, refreshedToken)))
160 }
161
162 w.Header().Set("WWW-Authenticate", "scope=foo")
163 w.WriteHeader(http.StatusUnauthorized)
164 }))
165 defer server.Close()
166
167 u, err := url.Parse(server.URL)
168 if err != nil {
169 t.Fatal(err)
170 }
171 registry, err := name.NewRegistry(u.Host, name.WeakValidation)
172 if err != nil {
173 t.Fatalf("Unexpected error during NewRegistry: %v", err)
174 }
175
176
177 transport := &bearerTransport{
178 inner: http.DefaultTransport,
179 bearer: authn.AuthConfig{RegistryToken: initialToken},
180 basic: &authn.Basic{Username: "foo", Password: "bar"},
181 registry: registry,
182 realm: server.URL,
183 scheme: "http",
184 }
185 client := http.Client{Transport: transport}
186
187 res, err := client.Get(fmt.Sprintf("http://%s/v2/foo/bar/blobs/blah", u.Host))
188 if err != nil {
189 t.Errorf("Unexpected error during client.Get: %v", err)
190 return
191 }
192 if res.StatusCode != http.StatusOK {
193 t.Errorf("client.Get final StatusCode got %v, want: %v", res.StatusCode, http.StatusOK)
194 }
195 if got, want := transport.bearer.RegistryToken, refreshedToken; got != want {
196 t.Errorf("Expected Bearer token to be refreshed, got %v, want %v", got, want)
197 }
198
199
200 transport.bearer = authn.AuthConfig{RegistryToken: initialToken}
201 transport.basic = &authn.Bearer{Token: refreshedToken}
202 client = http.Client{Transport: transport}
203
204 res, err = client.Get(fmt.Sprintf("http://%s/v2/foo/bar/blobs/blah", u.Host))
205 if err != nil {
206 t.Errorf("Unexpected error during client.Get: %v", err)
207 return
208 }
209 if res.StatusCode != http.StatusOK {
210 t.Errorf("client.Get final StatusCode got %v, want: %v", res.StatusCode, http.StatusOK)
211 }
212 if got, want := transport.bearer.RegistryToken, refreshedToken; got != want {
213 t.Errorf("Expected Bearer token to be refreshed, got %v, want %v", got, want)
214 }
215 }
216
217 func TestBearerTransportOauthRefresh(t *testing.T) {
218 initialToken := "foo"
219 accessToken := "bar"
220 refreshToken := "baz"
221
222 server := httptest.NewServer(
223 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
224 if r.Method == http.MethodPost {
225 if err := r.ParseForm(); err != nil {
226 t.Fatal(err)
227 }
228 if it := r.FormValue("refresh_token"); it != initialToken {
229 t.Errorf("want %s got %s", initialToken, it)
230 }
231 w.WriteHeader(http.StatusOK)
232 w.Write([]byte(fmt.Sprintf(`{"access_token": %q, "refresh_token": %q}`, accessToken, refreshToken)))
233 return
234 }
235
236 hdr := r.Header.Get("Authorization")
237 if hdr == "Bearer "+accessToken {
238 w.WriteHeader(http.StatusOK)
239 return
240 }
241
242 w.Header().Set("WWW-Authenticate", "scope=foo")
243 w.WriteHeader(http.StatusUnauthorized)
244 }))
245 defer server.Close()
246
247 u, err := url.Parse(server.URL)
248 if err != nil {
249 t.Fatal(err)
250 }
251 registry, err := name.NewRegistry(u.Host, name.WeakValidation)
252 if err != nil {
253 t.Errorf("Unexpected error during NewRegistry: %v", err)
254 }
255
256 transport := &bearerTransport{
257 inner: http.DefaultTransport,
258 basic: authn.FromConfig(authn.AuthConfig{IdentityToken: initialToken}),
259 registry: registry,
260 realm: server.URL,
261 scheme: "http",
262 scopes: []string{"myscope"},
263 service: u.Host,
264 }
265 client := http.Client{Transport: transport}
266
267 res, err := client.Get(fmt.Sprintf("http://%s/v2/foo/bar/blobs/blah", u.Host))
268 if err != nil {
269 t.Fatalf("Unexpected error during client.Get: %v", err)
270 }
271 if res.StatusCode != http.StatusOK {
272 t.Errorf("client.Get final StatusCode got %v, want: %v", res.StatusCode, http.StatusOK)
273 }
274 if want, got := transport.bearer.RegistryToken, accessToken; want != got {
275 t.Errorf("Expected Bearer token to be refreshed, got %v, want %v", got, want)
276 }
277 basicAuthConfig, err := transport.basic.Authorization()
278 if err != nil {
279 t.Fatal(err)
280 }
281 if got, want := basicAuthConfig.IdentityToken, refreshToken; got != want {
282 t.Errorf("Expected Basic IdentityToken to be refreshed, got %v, want %v", got, want)
283 }
284 }
285
286 func TestBearerTransportOauth404Fallback(t *testing.T) {
287 basicAuth := "basic_auth"
288 identityToken := "identity_token"
289 accessToken := "access_token"
290
291 server := httptest.NewServer(
292 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
293 if r.Method == http.MethodPost {
294 w.WriteHeader(http.StatusNotFound)
295 }
296
297 hdr := r.Header.Get("Authorization")
298 if hdr == "Basic "+basicAuth {
299 w.WriteHeader(http.StatusOK)
300 w.Write([]byte(fmt.Sprintf(`{"access_token": %q}`, accessToken)))
301 }
302 if hdr == "Bearer "+accessToken {
303 w.WriteHeader(http.StatusOK)
304 return
305 }
306
307 w.Header().Set("WWW-Authenticate", "scope=foo")
308 w.WriteHeader(http.StatusUnauthorized)
309 }))
310 defer server.Close()
311
312 u, err := url.Parse(server.URL)
313 if err != nil {
314 t.Fatal(err)
315 }
316 registry, err := name.NewRegistry(u.Host, name.WeakValidation)
317 if err != nil {
318 t.Errorf("Unexpected error during NewRegistry: %v", err)
319 }
320
321 transport := &bearerTransport{
322 inner: http.DefaultTransport,
323 basic: authn.FromConfig(authn.AuthConfig{
324 IdentityToken: identityToken,
325 Auth: basicAuth,
326 }),
327 registry: registry,
328 realm: server.URL,
329 scheme: "http",
330 scopes: []string{"myscope"},
331 service: u.Host,
332 }
333 client := http.Client{Transport: transport}
334
335 res, err := client.Get(fmt.Sprintf("http://%s/v2/foo/bar/blobs/blah", u.Host))
336 if err != nil {
337 t.Fatalf("Unexpected error during client.Get: %v", err)
338 }
339 if res.StatusCode != http.StatusOK {
340 t.Errorf("client.Get final StatusCode got %v, want: %v", res.StatusCode, http.StatusOK)
341 }
342 if got, want := transport.bearer.RegistryToken, accessToken; got != want {
343 t.Errorf("Expected Bearer token to be refreshed, got %v, want %v", got, want)
344 }
345 }
346
// recorder is an http.RoundTripper test double. Each request handed to
// RoundTrip is appended to reqs, and the same canned resp/err pair is
// returned every time.
type recorder struct {
	reqs []*http.Request
	resp *http.Response
	err  error
}

// newRecorder builds a recorder that answers every RoundTrip with resp and
// err. Its request log starts empty but non-nil.
func newRecorder(resp *http.Response, err error) *recorder {
	rec := new(recorder)
	rec.reqs = make([]*http.Request, 0)
	rec.resp, rec.err = resp, err
	return rec
}

// RoundTrip logs req and returns the canned response and error unchanged.
func (r *recorder) RoundTrip(req *http.Request) (*http.Response, error) {
	r.reqs = append(r.reqs, req)
	return r.resp, r.err
}
365
366 func TestSchemeOverride(t *testing.T) {
367
368 cannedResponse := http.Response{
369 Status: http.StatusText(http.StatusOK),
370 StatusCode: http.StatusOK,
371 }
372 rec := newRecorder(&cannedResponse, nil)
373 registry, err := name.NewRegistry("example.com")
374 if err != nil {
375 t.Fatalf("Unexpected error during NewRegistry: %v", err)
376 }
377 st := &schemeTransport{
378 inner: rec,
379 registry: registry,
380 scheme: "http",
381 }
382
383
384
385 tests := []struct {
386 url string
387 wantScheme string
388 }{{
389 url: "https://example.com",
390 wantScheme: "http",
391 }, {
392 url: "https://token.example.com",
393 wantScheme: "https",
394 }}
395
396 for i, tt := range tests {
397 req, err := http.NewRequest("GET", tt.url, nil)
398 if err != nil {
399 t.Fatalf("Unexpected error during NewRequest: %v", err)
400 }
401
402 if _, err := st.RoundTrip(req); err != nil {
403 t.Fatalf("Unexpected error during RoundTrip: %v", err)
404 }
405
406 if got, want := rec.reqs[i].URL.Scheme, tt.wantScheme; got != want {
407 t.Errorf("Wrong scheme: wanted %v, got %v", want, got)
408 }
409 }
410 }
411
412 func TestCanonicalAddressResolution(t *testing.T) {
413 registry, err := name.NewRegistry("does-not-matter", name.WeakValidation)
414 if err != nil {
415 t.Errorf("Unexpected error during NewRegistry: %v", err)
416 }
417
418 tests := []struct {
419 registry name.Registry
420 scheme string
421 address string
422 want string
423 }{{
424 registry: registry,
425 scheme: "http",
426 address: "registry.example.com",
427 want: "registry.example.com:80",
428 }, {
429 registry: registry,
430 scheme: "http",
431 address: "registry.example.com:12345",
432 want: "registry.example.com:12345",
433 }, {
434 registry: registry,
435 scheme: "https",
436 address: "registry.example.com",
437 want: "registry.example.com:443",
438 }, {
439 registry: registry,
440 scheme: "https",
441 address: "registry.example.com:12345",
442 want: "registry.example.com:12345",
443 }, {
444 registry: registry,
445 scheme: "http",
446 address: "registry.example.com:",
447 want: "registry.example.com:80",
448 }, {
449 registry: registry,
450 scheme: "https",
451 address: "registry.example.com:",
452 want: "registry.example.com:443",
453 }, {
454 registry: registry,
455 scheme: "http",
456 address: "2001:db8::1",
457 want: "[2001:db8::1]:80",
458 }, {
459 registry: registry,
460 scheme: "https",
461 address: "2001:db8::1",
462 want: "[2001:db8::1]:443",
463 }, {
464 registry: registry,
465 scheme: "http",
466 address: "[2001:db8::1]:12345",
467 want: "[2001:db8::1]:12345",
468 }, {
469 registry: registry,
470 scheme: "https",
471 address: "[2001:db8::1]:12345",
472 want: "[2001:db8::1]:12345",
473 }, {
474 registry: registry,
475 scheme: "http",
476 address: "[2001:db8::1]:",
477 want: "[2001:db8::1]:80",
478 }, {
479 registry: registry,
480 scheme: "https",
481 address: "[2001:db8::1]:",
482 want: "[2001:db8::1]:443",
483 }, {
484 registry: registry,
485 scheme: "https",
486 address: "something:is::wrong]:",
487 want: "something:is::wrong]:",
488 }}
489
490 for _, tt := range tests {
491 got := canonicalAddress(tt.address, tt.scheme)
492 if got != tt.want {
493 t.Errorf("Wrong canonical host: wanted %v got %v", tt.want, got)
494 }
495 }
496 }
497
498 func TestInsufficientScope(t *testing.T) {
499 wrong := "the-wrong-scope"
500 right := "the-right-scope"
501 realm := ""
502 expectedService := "my-service.io"
503 passed := false
504
505 server := httptest.NewServer(
506 http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
507 query := r.URL.Query()
508
509 scopes := query["scope"]
510 switch {
511 case len(scopes) == 0:
512 if !passed {
513 w.Header().Set("WWW-Authenticate", fmt.Sprintf("Bearer realm=%q,scope=%q", realm, right))
514 w.WriteHeader(http.StatusUnauthorized)
515 }
516 case len(scopes) == 1:
517 w.Write([]byte(`{"token": "arbitrary-token"}`))
518 default:
519 passed = true
520 w.Write([]byte(`{"token": "arbitrary-token-2"}`))
521 }
522 }))
523 defer server.Close()
524
525 basic := &authn.Basic{Username: "foo", Password: "bar"}
526 u, err := url.Parse(server.URL)
527 if err != nil {
528 t.Error("Unexpected error during url.Parse: ", err)
529 }
530 realm = u.Host
531
532 registry, err := name.NewRegistry(expectedService, name.WeakValidation)
533 if err != nil {
534 t.Error("Unexpected error during NewRegistry: ", err)
535 }
536
537 bt := &bearerTransport{
538 inner: http.DefaultTransport,
539 basic: basic,
540 registry: registry,
541 realm: server.URL,
542 scopes: []string{wrong},
543 service: expectedService,
544 scheme: "http",
545 }
546
547 client := http.Client{Transport: bt}
548
549 res, err := client.Get(fmt.Sprintf("http://%s/v2/foo/bar/blobs/blah", u.Host))
550 if err != nil {
551 t.Error("Unexpected error during client.Get: ", err)
552 return
553 }
554 if res.StatusCode != http.StatusOK {
555 t.Errorf("client.Get final StatusCode got %v, want: %v", res.StatusCode, http.StatusOK)
556 }
557
558 if !passed {
559 t.Error("didn't refresh insufficient scope")
560 }
561 }
562
View as plain text