...

Source file src/github.com/google/go-containerregistry/pkg/v1/remote/transport/bearer_test.go

Documentation: github.com/google/go-containerregistry/pkg/v1/remote/transport

     1  // Copyright 2018 Google LLC All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //    http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package transport
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"net/http"
    21  	"net/http/httptest"
    22  	"net/url"
    23  	"strings"
    24  	"testing"
    25  
    26  	"github.com/google/go-containerregistry/pkg/authn"
    27  	"github.com/google/go-containerregistry/pkg/name"
    28  )
    29  
    30  func TestBearerRefresh(t *testing.T) {
    31  	expectedToken := "Sup3rDup3rS3cr3tz"
    32  	expectedScope := "this-is-your-scope"
    33  	expectedService := "my-service.io"
    34  
    35  	cases := []struct {
    36  		tokenKey string
    37  		wantErr  bool
    38  	}{{
    39  		tokenKey: "token",
    40  		wantErr:  false,
    41  	}, {
    42  		tokenKey: "access_token",
    43  		wantErr:  false,
    44  	}, {
    45  		tokenKey: "tolkien",
    46  		wantErr:  true,
    47  	}}
    48  
    49  	for _, tc := range cases {
    50  		t.Run(tc.tokenKey, func(t *testing.T) {
    51  			server := httptest.NewServer(
    52  				http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    53  					hdr := r.Header.Get("Authorization")
    54  					if !strings.HasPrefix(hdr, "Basic ") {
    55  						t.Errorf("Header.Get(Authorization); got %v, want Basic prefix", hdr)
    56  					}
    57  					if got, want := r.FormValue("scope"), expectedScope; got != want {
    58  						t.Errorf("FormValue(scope); got %v, want %v", got, want)
    59  					}
    60  					if got, want := r.FormValue("service"), expectedService; got != want {
    61  						t.Errorf("FormValue(service); got %v, want %v", got, want)
    62  					}
    63  					w.Write([]byte(fmt.Sprintf(`{%q: %q}`, tc.tokenKey, expectedToken)))
    64  				}))
    65  			defer server.Close()
    66  
    67  			basic := &authn.Basic{Username: "foo", Password: "bar"}
    68  			registry, err := name.NewRegistry(expectedService, name.WeakValidation)
    69  			if err != nil {
    70  				t.Errorf("Unexpected error during NewRegistry: %v", err)
    71  			}
    72  
    73  			bt := &bearerTransport{
    74  				inner:    http.DefaultTransport,
    75  				basic:    basic,
    76  				registry: registry,
    77  				realm:    server.URL,
    78  				scopes:   []string{expectedScope},
    79  				service:  expectedService,
    80  				scheme:   "http",
    81  			}
    82  
    83  			if err := bt.refresh(context.Background()); (err != nil) != tc.wantErr {
    84  				t.Errorf("refresh() = %v", err)
    85  			}
    86  		})
    87  	}
    88  }
    89  
    90  func TestBearerTransport(t *testing.T) {
    91  	expectedToken := "sdkjhfskjdhfkjshdf"
    92  
    93  	blobServer := httptest.NewServer(
    94  		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    95  			// We don't expect the blobServer to receive bearer tokens.
    96  			if got := r.Header.Get("Authorization"); got != "" {
    97  				t.Errorf("Header.Get(Authorization); got %v, want empty string", got)
    98  			}
    99  			w.WriteHeader(http.StatusOK)
   100  		}))
   101  	defer blobServer.Close()
   102  
   103  	server := httptest.NewServer(
   104  		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   105  			if got, want := r.Header.Get("Authorization"), "Bearer "+expectedToken; got != want {
   106  				t.Errorf("Header.Get(Authorization); got %v, want %v", got, want)
   107  			}
   108  			if r.URL.Path == "/v2/auth" {
   109  				http.Redirect(w, r, "/redirect", http.StatusMovedPermanently)
   110  				return
   111  			}
   112  			if strings.Contains(r.URL.Path, "blobs") {
   113  				http.Redirect(w, r, blobServer.URL, http.StatusFound)
   114  				return
   115  			}
   116  			w.WriteHeader(http.StatusOK)
   117  		}))
   118  	defer server.Close()
   119  
   120  	u, err := url.Parse(server.URL)
   121  	if err != nil {
   122  		t.Errorf("Unexpected error during url.Parse: %v", err)
   123  	}
   124  	registry, err := name.NewRegistry(u.Host, name.WeakValidation)
   125  	if err != nil {
   126  		t.Errorf("Unexpected error during NewRegistry: %v", err)
   127  	}
   128  
   129  	client := http.Client{Transport: &bearerTransport{
   130  		inner:    &http.Transport{},
   131  		bearer:   authn.AuthConfig{RegistryToken: expectedToken},
   132  		registry: registry,
   133  		scheme:   "http",
   134  	}}
   135  
   136  	_, err = client.Get(fmt.Sprintf("http://%s/v2/auth", u.Host))
   137  	if err != nil {
   138  		t.Errorf("Unexpected error during Get: %v", err)
   139  	}
   140  
   141  	_, err = client.Get(fmt.Sprintf("http://%s/v2/foo/bar/blobs/blah", u.Host))
   142  	if err != nil {
   143  		t.Errorf("Unexpected error during Get: %v", err)
   144  	}
   145  }
   146  
   147  func TestBearerTransportTokenRefresh(t *testing.T) {
   148  	initialToken := "foo"
   149  	refreshedToken := "bar"
   150  
   151  	server := httptest.NewServer(
   152  		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   153  			hdr := r.Header.Get("Authorization")
   154  			if hdr == "Bearer "+refreshedToken {
   155  				w.WriteHeader(http.StatusOK)
   156  				return
   157  			}
   158  			if strings.HasPrefix(hdr, "Basic ") {
   159  				w.Write([]byte(fmt.Sprintf(`{"token": %q}`, refreshedToken)))
   160  			}
   161  
   162  			w.Header().Set("WWW-Authenticate", "scope=foo")
   163  			w.WriteHeader(http.StatusUnauthorized)
   164  		}))
   165  	defer server.Close()
   166  
   167  	u, err := url.Parse(server.URL)
   168  	if err != nil {
   169  		t.Fatal(err)
   170  	}
   171  	registry, err := name.NewRegistry(u.Host, name.WeakValidation)
   172  	if err != nil {
   173  		t.Fatalf("Unexpected error during NewRegistry: %v", err)
   174  	}
   175  
   176  	// Pass Username/Password
   177  	transport := &bearerTransport{
   178  		inner:    http.DefaultTransport,
   179  		bearer:   authn.AuthConfig{RegistryToken: initialToken},
   180  		basic:    &authn.Basic{Username: "foo", Password: "bar"},
   181  		registry: registry,
   182  		realm:    server.URL,
   183  		scheme:   "http",
   184  	}
   185  	client := http.Client{Transport: transport}
   186  
   187  	res, err := client.Get(fmt.Sprintf("http://%s/v2/foo/bar/blobs/blah", u.Host))
   188  	if err != nil {
   189  		t.Errorf("Unexpected error during client.Get: %v", err)
   190  		return
   191  	}
   192  	if res.StatusCode != http.StatusOK {
   193  		t.Errorf("client.Get final StatusCode got %v, want: %v", res.StatusCode, http.StatusOK)
   194  	}
   195  	if got, want := transport.bearer.RegistryToken, refreshedToken; got != want {
   196  		t.Errorf("Expected Bearer token to be refreshed, got %v, want %v", got, want)
   197  	}
   198  
   199  	// Pass RegistryToken directly
   200  	transport.bearer = authn.AuthConfig{RegistryToken: initialToken}
   201  	transport.basic = &authn.Bearer{Token: refreshedToken}
   202  	client = http.Client{Transport: transport}
   203  
   204  	res, err = client.Get(fmt.Sprintf("http://%s/v2/foo/bar/blobs/blah", u.Host))
   205  	if err != nil {
   206  		t.Errorf("Unexpected error during client.Get: %v", err)
   207  		return
   208  	}
   209  	if res.StatusCode != http.StatusOK {
   210  		t.Errorf("client.Get final StatusCode got %v, want: %v", res.StatusCode, http.StatusOK)
   211  	}
   212  	if got, want := transport.bearer.RegistryToken, refreshedToken; got != want {
   213  		t.Errorf("Expected Bearer token to be refreshed, got %v, want %v", got, want)
   214  	}
   215  }
   216  
   217  func TestBearerTransportOauthRefresh(t *testing.T) {
   218  	initialToken := "foo"
   219  	accessToken := "bar"
   220  	refreshToken := "baz"
   221  
   222  	server := httptest.NewServer(
   223  		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   224  			if r.Method == http.MethodPost {
   225  				if err := r.ParseForm(); err != nil {
   226  					t.Fatal(err)
   227  				}
   228  				if it := r.FormValue("refresh_token"); it != initialToken {
   229  					t.Errorf("want %s got %s", initialToken, it)
   230  				}
   231  				w.WriteHeader(http.StatusOK)
   232  				w.Write([]byte(fmt.Sprintf(`{"access_token": %q, "refresh_token": %q}`, accessToken, refreshToken)))
   233  				return
   234  			}
   235  
   236  			hdr := r.Header.Get("Authorization")
   237  			if hdr == "Bearer "+accessToken {
   238  				w.WriteHeader(http.StatusOK)
   239  				return
   240  			}
   241  
   242  			w.Header().Set("WWW-Authenticate", "scope=foo")
   243  			w.WriteHeader(http.StatusUnauthorized)
   244  		}))
   245  	defer server.Close()
   246  
   247  	u, err := url.Parse(server.URL)
   248  	if err != nil {
   249  		t.Fatal(err)
   250  	}
   251  	registry, err := name.NewRegistry(u.Host, name.WeakValidation)
   252  	if err != nil {
   253  		t.Errorf("Unexpected error during NewRegistry: %v", err)
   254  	}
   255  
   256  	transport := &bearerTransport{
   257  		inner:    http.DefaultTransport,
   258  		basic:    authn.FromConfig(authn.AuthConfig{IdentityToken: initialToken}),
   259  		registry: registry,
   260  		realm:    server.URL,
   261  		scheme:   "http",
   262  		scopes:   []string{"myscope"},
   263  		service:  u.Host,
   264  	}
   265  	client := http.Client{Transport: transport}
   266  
   267  	res, err := client.Get(fmt.Sprintf("http://%s/v2/foo/bar/blobs/blah", u.Host))
   268  	if err != nil {
   269  		t.Fatalf("Unexpected error during client.Get: %v", err)
   270  	}
   271  	if res.StatusCode != http.StatusOK {
   272  		t.Errorf("client.Get final StatusCode got %v, want: %v", res.StatusCode, http.StatusOK)
   273  	}
   274  	if want, got := transport.bearer.RegistryToken, accessToken; want != got {
   275  		t.Errorf("Expected Bearer token to be refreshed, got %v, want %v", got, want)
   276  	}
   277  	basicAuthConfig, err := transport.basic.Authorization()
   278  	if err != nil {
   279  		t.Fatal(err)
   280  	}
   281  	if got, want := basicAuthConfig.IdentityToken, refreshToken; got != want {
   282  		t.Errorf("Expected Basic IdentityToken to be refreshed, got %v, want %v", got, want)
   283  	}
   284  }
   285  
   286  func TestBearerTransportOauth404Fallback(t *testing.T) {
   287  	basicAuth := "basic_auth"
   288  	identityToken := "identity_token"
   289  	accessToken := "access_token"
   290  
   291  	server := httptest.NewServer(
   292  		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   293  			if r.Method == http.MethodPost {
   294  				w.WriteHeader(http.StatusNotFound)
   295  			}
   296  
   297  			hdr := r.Header.Get("Authorization")
   298  			if hdr == "Basic "+basicAuth {
   299  				w.WriteHeader(http.StatusOK)
   300  				w.Write([]byte(fmt.Sprintf(`{"access_token": %q}`, accessToken)))
   301  			}
   302  			if hdr == "Bearer "+accessToken {
   303  				w.WriteHeader(http.StatusOK)
   304  				return
   305  			}
   306  
   307  			w.Header().Set("WWW-Authenticate", "scope=foo")
   308  			w.WriteHeader(http.StatusUnauthorized)
   309  		}))
   310  	defer server.Close()
   311  
   312  	u, err := url.Parse(server.URL)
   313  	if err != nil {
   314  		t.Fatal(err)
   315  	}
   316  	registry, err := name.NewRegistry(u.Host, name.WeakValidation)
   317  	if err != nil {
   318  		t.Errorf("Unexpected error during NewRegistry: %v", err)
   319  	}
   320  
   321  	transport := &bearerTransport{
   322  		inner: http.DefaultTransport,
   323  		basic: authn.FromConfig(authn.AuthConfig{
   324  			IdentityToken: identityToken,
   325  			Auth:          basicAuth,
   326  		}),
   327  		registry: registry,
   328  		realm:    server.URL,
   329  		scheme:   "http",
   330  		scopes:   []string{"myscope"},
   331  		service:  u.Host,
   332  	}
   333  	client := http.Client{Transport: transport}
   334  
   335  	res, err := client.Get(fmt.Sprintf("http://%s/v2/foo/bar/blobs/blah", u.Host))
   336  	if err != nil {
   337  		t.Fatalf("Unexpected error during client.Get: %v", err)
   338  	}
   339  	if res.StatusCode != http.StatusOK {
   340  		t.Errorf("client.Get final StatusCode got %v, want: %v", res.StatusCode, http.StatusOK)
   341  	}
   342  	if got, want := transport.bearer.RegistryToken, accessToken; got != want {
   343  		t.Errorf("Expected Bearer token to be refreshed, got %v, want %v", got, want)
   344  	}
   345  }
   346  
   347  type recorder struct {
   348  	reqs []*http.Request
   349  	resp *http.Response
   350  	err  error
   351  }
   352  
   353  func newRecorder(resp *http.Response, err error) *recorder {
   354  	return &recorder{
   355  		reqs: []*http.Request{},
   356  		resp: resp,
   357  		err:  err,
   358  	}
   359  }
   360  
   361  func (r *recorder) RoundTrip(in *http.Request) (*http.Response, error) {
   362  	r.reqs = append(r.reqs, in)
   363  	return r.resp, r.err
   364  }
   365  
   366  func TestSchemeOverride(t *testing.T) {
   367  	// Record the requests we get in the inner transport.
   368  	cannedResponse := http.Response{
   369  		Status:     http.StatusText(http.StatusOK),
   370  		StatusCode: http.StatusOK,
   371  	}
   372  	rec := newRecorder(&cannedResponse, nil)
   373  	registry, err := name.NewRegistry("example.com")
   374  	if err != nil {
   375  		t.Fatalf("Unexpected error during NewRegistry: %v", err)
   376  	}
   377  	st := &schemeTransport{
   378  		inner:    rec,
   379  		registry: registry,
   380  		scheme:   "http",
   381  	}
   382  
   383  	// We should see the scheme be overridden to "http" for the registry, but the
   384  	// scheme for the token server should be unchanged.
   385  	tests := []struct {
   386  		url        string
   387  		wantScheme string
   388  	}{{
   389  		url:        "https://example.com",
   390  		wantScheme: "http",
   391  	}, {
   392  		url:        "https://token.example.com",
   393  		wantScheme: "https",
   394  	}}
   395  
   396  	for i, tt := range tests {
   397  		req, err := http.NewRequest("GET", tt.url, nil)
   398  		if err != nil {
   399  			t.Fatalf("Unexpected error during NewRequest: %v", err)
   400  		}
   401  
   402  		if _, err := st.RoundTrip(req); err != nil {
   403  			t.Fatalf("Unexpected error during RoundTrip: %v", err)
   404  		}
   405  
   406  		if got, want := rec.reqs[i].URL.Scheme, tt.wantScheme; got != want {
   407  			t.Errorf("Wrong scheme: wanted %v, got %v", want, got)
   408  		}
   409  	}
   410  }
   411  
   412  func TestCanonicalAddressResolution(t *testing.T) {
   413  	registry, err := name.NewRegistry("does-not-matter", name.WeakValidation)
   414  	if err != nil {
   415  		t.Errorf("Unexpected error during NewRegistry: %v", err)
   416  	}
   417  
   418  	tests := []struct {
   419  		registry name.Registry
   420  		scheme   string
   421  		address  string
   422  		want     string
   423  	}{{
   424  		registry: registry,
   425  		scheme:   "http",
   426  		address:  "registry.example.com",
   427  		want:     "registry.example.com:80",
   428  	}, {
   429  		registry: registry,
   430  		scheme:   "http",
   431  		address:  "registry.example.com:12345",
   432  		want:     "registry.example.com:12345",
   433  	}, {
   434  		registry: registry,
   435  		scheme:   "https",
   436  		address:  "registry.example.com",
   437  		want:     "registry.example.com:443",
   438  	}, {
   439  		registry: registry,
   440  		scheme:   "https",
   441  		address:  "registry.example.com:12345",
   442  		want:     "registry.example.com:12345",
   443  	}, {
   444  		registry: registry,
   445  		scheme:   "http",
   446  		address:  "registry.example.com:",
   447  		want:     "registry.example.com:80",
   448  	}, {
   449  		registry: registry,
   450  		scheme:   "https",
   451  		address:  "registry.example.com:",
   452  		want:     "registry.example.com:443",
   453  	}, {
   454  		registry: registry,
   455  		scheme:   "http",
   456  		address:  "2001:db8::1",
   457  		want:     "[2001:db8::1]:80",
   458  	}, {
   459  		registry: registry,
   460  		scheme:   "https",
   461  		address:  "2001:db8::1",
   462  		want:     "[2001:db8::1]:443",
   463  	}, {
   464  		registry: registry,
   465  		scheme:   "http",
   466  		address:  "[2001:db8::1]:12345",
   467  		want:     "[2001:db8::1]:12345",
   468  	}, {
   469  		registry: registry,
   470  		scheme:   "https",
   471  		address:  "[2001:db8::1]:12345",
   472  		want:     "[2001:db8::1]:12345",
   473  	}, {
   474  		registry: registry,
   475  		scheme:   "http",
   476  		address:  "[2001:db8::1]:",
   477  		want:     "[2001:db8::1]:80",
   478  	}, {
   479  		registry: registry,
   480  		scheme:   "https",
   481  		address:  "[2001:db8::1]:",
   482  		want:     "[2001:db8::1]:443",
   483  	}, {
   484  		registry: registry,
   485  		scheme:   "https",
   486  		address:  "something:is::wrong]:",
   487  		want:     "something:is::wrong]:",
   488  	}}
   489  
   490  	for _, tt := range tests {
   491  		got := canonicalAddress(tt.address, tt.scheme)
   492  		if got != tt.want {
   493  			t.Errorf("Wrong canonical host: wanted %v got %v", tt.want, got)
   494  		}
   495  	}
   496  }
   497  
   498  func TestInsufficientScope(t *testing.T) {
   499  	wrong := "the-wrong-scope"
   500  	right := "the-right-scope"
   501  	realm := ""
   502  	expectedService := "my-service.io"
   503  	passed := false
   504  
   505  	server := httptest.NewServer(
   506  		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   507  			query := r.URL.Query()
   508  
   509  			scopes := query["scope"]
   510  			switch {
   511  			case len(scopes) == 0:
   512  				if !passed {
   513  					w.Header().Set("WWW-Authenticate", fmt.Sprintf("Bearer realm=%q,scope=%q", realm, right))
   514  					w.WriteHeader(http.StatusUnauthorized)
   515  				}
   516  			case len(scopes) == 1:
   517  				w.Write([]byte(`{"token": "arbitrary-token"}`))
   518  			default:
   519  				passed = true
   520  				w.Write([]byte(`{"token": "arbitrary-token-2"}`))
   521  			}
   522  		}))
   523  	defer server.Close()
   524  
   525  	basic := &authn.Basic{Username: "foo", Password: "bar"}
   526  	u, err := url.Parse(server.URL)
   527  	if err != nil {
   528  		t.Error("Unexpected error during url.Parse: ", err)
   529  	}
   530  	realm = u.Host
   531  
   532  	registry, err := name.NewRegistry(expectedService, name.WeakValidation)
   533  	if err != nil {
   534  		t.Error("Unexpected error during NewRegistry: ", err)
   535  	}
   536  
   537  	bt := &bearerTransport{
   538  		inner:    http.DefaultTransport,
   539  		basic:    basic,
   540  		registry: registry,
   541  		realm:    server.URL,
   542  		scopes:   []string{wrong},
   543  		service:  expectedService,
   544  		scheme:   "http",
   545  	}
   546  
   547  	client := http.Client{Transport: bt}
   548  
   549  	res, err := client.Get(fmt.Sprintf("http://%s/v2/foo/bar/blobs/blah", u.Host))
   550  	if err != nil {
   551  		t.Error("Unexpected error during client.Get: ", err)
   552  		return
   553  	}
   554  	if res.StatusCode != http.StatusOK {
   555  		t.Errorf("client.Get final StatusCode got %v, want: %v", res.StatusCode, http.StatusOK)
   556  	}
   557  
   558  	if !passed {
   559  		t.Error("didn't refresh insufficient scope")
   560  	}
   561  }
   562  

View as plain text