...

Source file src/oras.land/oras-go/pkg/registry/remote/auth/client.go

Documentation: oras.land/oras-go/pkg/registry/remote/auth

     1  /*
     2  Copyright The ORAS Authors.
     3  Licensed under the Apache License, Version 2.0 (the "License");
     4  you may not use this file except in compliance with the License.
     5  You may obtain a copy of the License at
     6  
     7  http://www.apache.org/licenses/LICENSE-2.0
     8  
     9  Unless required by applicable law or agreed to in writing, software
    10  distributed under the License is distributed on an "AS IS" BASIS,
    11  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  See the License for the specific language governing permissions and
    13  limitations under the License.
    14  */
    15  package auth
    16  
    17  import (
    18  	"context"
    19  	"encoding/base64"
    20  	"encoding/json"
    21  	"errors"
    22  	"fmt"
    23  	"io"
    24  	"net/http"
    25  	"net/url"
    26  	"strings"
    27  
    28  	"oras.land/oras-go/pkg/registry/remote/internal/errutil"
    29  )
    30  
    31  // DefaultClient is the default auth-decorated client.
    32  var DefaultClient = &Client{
    33  	Header: http.Header{
    34  		"User-Agent": {"oras-go"},
    35  	},
    36  	Cache: DefaultCache,
    37  }
    38  
    39  // maxResponseBytes specifies the default limit on how many response bytes are
    40  // allowed in the server's response from authorization service servers.
    41  // A typical response message from authorization service servers is around 1 to
    42  // 4 KiB. Since the size of a token must be smaller than the HTTP header size
    43  // limit, which is usually 16 KiB. As specified by the distribution, the
    44  // response may contain 2 identical tokens, that is, 16 x 2 = 32 KiB.
    45  // Hence, 128 KiB should be sufficient.
    46  // References: https://docs.docker.com/registry/spec/auth/token/
    47  var maxResponseBytes int64 = 128 * 1024 // 128 KiB
    48  
    49  // defaultClientID specifies the default client ID used in OAuth2.
    50  // See also ClientID.
    51  var defaultClientID = "oras-go"
    52  
// Client is an auth-decorated HTTP client.
// Its zero value is a usable client that uses http.DefaultClient with no cache.
type Client struct {
	// Client is the underlying HTTP client used to access the remote
	// server.
	// If nil, http.DefaultClient is used.
	Client *http.Client

	// Header contains the custom headers to be added to each request.
	// Its values are appended to any values the request already carries for
	// the same key.
	Header http.Header

	// Credential specifies the function for resolving the credential for the
	// given registry (i.e. host:port).
	// `EmptyCredential` is a valid return value and should not be considered as
	// an error.
	// If nil, the credential is always resolved to `EmptyCredential`.
	Credential func(context.Context, string) (Credential, error)

	// Cache caches authentication schemes and tokens for directly accessing
	// the remote registry.
	// If nil, no cache is used.
	Cache Cache

	// ClientID is the client ID sent when fetching an OAuth2 token, where it
	// is a required field.
	// If empty, a default client ID is used.
	// Reference: https://docs.docker.com/registry/spec/auth/oauth/#getting-a-token
	ClientID string

	// ForceAttemptOAuth2 controls whether to follow OAuth2 with password grant
	// instead of the distribution spec when authenticating using username and
	// password.
	// References:
	// - https://docs.docker.com/registry/spec/auth/jwt/
	// - https://docs.docker.com/registry/spec/auth/oauth/
	ForceAttemptOAuth2 bool
}
    88  
    89  // client returns an HTTP client used to access the remote registry.
    90  // http.DefaultClient is return if the client is not configured.
    91  func (c *Client) client() *http.Client {
    92  	if c.Client == nil {
    93  		return http.DefaultClient
    94  	}
    95  	return c.Client
    96  }
    97  
    98  // send adds headers to the request and sends the request to the remote server.
    99  func (c *Client) send(req *http.Request) (*http.Response, error) {
   100  	for key, values := range c.Header {
   101  		req.Header[key] = append(req.Header[key], values...)
   102  	}
   103  	return c.client().Do(req)
   104  }
   105  
   106  // credential resolves the credential for the given registry.
   107  func (c *Client) credential(ctx context.Context, reg string) (Credential, error) {
   108  	if c.Credential == nil {
   109  		return EmptyCredential, nil
   110  	}
   111  	return c.Credential(ctx, reg)
   112  }
   113  
   114  // cache resolves the cache.
   115  // noCache is return if the cache is not configured.
   116  func (c *Client) cache() Cache {
   117  	if c.Cache == nil {
   118  		return noCache{}
   119  	}
   120  	return c.Cache
   121  }
   122  
   123  // SetUserAgent sets the user agent for all out-going requests.
   124  func (c *Client) SetUserAgent(userAgent string) {
   125  	if c.Header == nil {
   126  		c.Header = http.Header{}
   127  	}
   128  	c.Header.Set("User-Agent", userAgent)
   129  }
   130  
   131  // Do sends the request to the remote server with resolving authentication
   132  // attempted.
   133  // On authentication failure due to bad credential,
   134  // - Do returns error if it fails to fetch token for bearer auth.
   135  // - Do returns the registry response without error for basic auth.
   136  func (c *Client) Do(originalReq *http.Request) (*http.Response, error) {
   137  	ctx := originalReq.Context()
   138  	req := originalReq.Clone(ctx)
   139  
   140  	// attempt cached auth token
   141  	var attemptedKey string
   142  	cache := c.cache()
   143  	registry := originalReq.Host
   144  	scheme, err := cache.GetScheme(ctx, registry)
   145  	if err == nil {
   146  		switch scheme {
   147  		case SchemeBasic:
   148  			token, err := cache.GetToken(ctx, registry, SchemeBasic, "")
   149  			if err == nil {
   150  				req.Header.Set("Authorization", "Basic "+token)
   151  			}
   152  		case SchemeBearer:
   153  			scopes := GetScopes(ctx)
   154  			attemptedKey = strings.Join(scopes, " ")
   155  			token, err := cache.GetToken(ctx, registry, SchemeBearer, attemptedKey)
   156  			if err == nil {
   157  				req.Header.Set("Authorization", "Bearer "+token)
   158  			}
   159  		}
   160  	}
   161  
   162  	resp, err := c.send(req)
   163  	if err != nil {
   164  		return nil, err
   165  	}
   166  	if resp.StatusCode != http.StatusUnauthorized {
   167  		return resp, nil
   168  	}
   169  
   170  	// attempt again with credentials for recognized schemes
   171  	challenge := resp.Header.Get("Www-Authenticate")
   172  	scheme, params := parseChallenge(challenge)
   173  	switch scheme {
   174  	case SchemeBasic:
   175  		resp.Body.Close()
   176  
   177  		token, err := cache.Set(ctx, registry, SchemeBasic, "", func(ctx context.Context) (string, error) {
   178  			return c.fetchBasicAuth(ctx, registry)
   179  		})
   180  		if err != nil {
   181  			return nil, fmt.Errorf("%s %q: %w", resp.Request.Method, resp.Request.URL, err)
   182  		}
   183  
   184  		req = originalReq.Clone(ctx)
   185  		req.Header.Set("Authorization", "Basic "+token)
   186  	case SchemeBearer:
   187  		resp.Body.Close()
   188  
   189  		// merge hinted scopes with challenged scopes
   190  		scopes := GetScopes(ctx)
   191  		if scope := params["scope"]; scope != "" {
   192  			scopes = append(scopes, strings.Split(scope, " ")...)
   193  			scopes = CleanScopes(scopes)
   194  		}
   195  		key := strings.Join(scopes, " ")
   196  
   197  		// attempt the cache again if there is a scope change
   198  		if key != attemptedKey {
   199  			if token, err := cache.GetToken(ctx, registry, SchemeBearer, key); err == nil {
   200  				req = originalReq.Clone(ctx)
   201  				req.Header.Set("Authorization", "Bearer "+token)
   202  
   203  				resp, err := c.send(req)
   204  				if err != nil {
   205  					return nil, err
   206  				}
   207  				if resp.StatusCode != http.StatusUnauthorized {
   208  					return resp, nil
   209  				}
   210  				resp.Body.Close()
   211  			}
   212  		}
   213  
   214  		// attempt with credentials
   215  		realm := params["realm"]
   216  		service := params["service"]
   217  		token, err := cache.Set(ctx, registry, SchemeBearer, key, func(ctx context.Context) (string, error) {
   218  			return c.fetchBearerToken(ctx, registry, realm, service, scopes)
   219  		})
   220  		if err != nil {
   221  			return nil, fmt.Errorf("%s %q: %w", resp.Request.Method, resp.Request.URL, err)
   222  		}
   223  
   224  		req = originalReq.Clone(ctx)
   225  		req.Header.Set("Authorization", "Bearer "+token)
   226  	default:
   227  		return resp, nil
   228  	}
   229  
   230  	return c.send(req)
   231  }
   232  
   233  // fetchBasicAuth fetches a basic auth token for the basic challenge.
   234  func (c *Client) fetchBasicAuth(ctx context.Context, registry string) (string, error) {
   235  	cred, err := c.credential(ctx, registry)
   236  	if err != nil {
   237  		return "", fmt.Errorf("failed to resolve credential: %w", err)
   238  	}
   239  	if cred == EmptyCredential {
   240  		return "", errors.New("credential required for basic auth")
   241  	}
   242  	if cred.Username == "" || cred.Password == "" {
   243  		return "", errors.New("missing username or password for basic auth")
   244  	}
   245  	auth := cred.Username + ":" + cred.Password
   246  	return base64.StdEncoding.EncodeToString([]byte(auth)), nil
   247  }
   248  
   249  // fetchBearerToken fetches an access token for the bearer challenge.
   250  func (c *Client) fetchBearerToken(ctx context.Context, registry, realm, service string, scopes []string) (string, error) {
   251  	cred, err := c.credential(ctx, registry)
   252  	if err != nil {
   253  		return "", err
   254  	}
   255  	if cred.AccessToken != "" {
   256  		return cred.AccessToken, nil
   257  	}
   258  	if cred == EmptyCredential || (cred.RefreshToken == "" && !c.ForceAttemptOAuth2) {
   259  		return c.fetchDistributionToken(ctx, realm, service, scopes, cred.Username, cred.Password)
   260  	}
   261  	return c.fetchOAuth2Token(ctx, realm, service, scopes, cred)
   262  }
   263  
   264  // fetchDistributionToken fetches an access token as defined by the distribution
   265  // specification.
   266  // It fetches anonymous tokens if no credential is provided.
   267  // References:
   268  // - https://docs.docker.com/registry/spec/auth/jwt/
   269  // - https://docs.docker.com/registry/spec/auth/token/
   270  func (c *Client) fetchDistributionToken(ctx context.Context, realm, service string, scopes []string, username, password string) (string, error) {
   271  	req, err := http.NewRequestWithContext(ctx, http.MethodGet, realm, nil)
   272  	if err != nil {
   273  		return "", err
   274  	}
   275  	if username != "" || password != "" {
   276  		req.SetBasicAuth(username, password)
   277  	}
   278  	q := req.URL.Query()
   279  	if service != "" {
   280  		q.Set("service", service)
   281  	}
   282  	for _, scope := range scopes {
   283  		q.Add("scope", scope)
   284  	}
   285  	req.URL.RawQuery = q.Encode()
   286  
   287  	resp, err := c.send(req)
   288  	if err != nil {
   289  		return "", err
   290  	}
   291  	defer resp.Body.Close()
   292  	if resp.StatusCode != http.StatusOK {
   293  		return "", errutil.ParseErrorResponse(resp)
   294  	}
   295  
   296  	// As specified in https://docs.docker.com/registry/spec/auth/token/ section
   297  	// "Token Response Fields", the token is either in `token` or
   298  	// `access_token`. If both present, they are identical.
   299  	var result struct {
   300  		Token       string `json:"token"`
   301  		AccessToken string `json:"access_token"`
   302  	}
   303  	lr := io.LimitReader(resp.Body, maxResponseBytes)
   304  	if err := json.NewDecoder(lr).Decode(&result); err != nil {
   305  		return "", fmt.Errorf("%s %q: failed to decode response: %w", resp.Request.Method, resp.Request.URL, err)
   306  	}
   307  	if result.AccessToken != "" {
   308  		return result.AccessToken, nil
   309  	}
   310  	if result.Token != "" {
   311  		return result.Token, nil
   312  	}
   313  	return "", fmt.Errorf("%s %q: empty token returned", resp.Request.Method, resp.Request.URL)
   314  }
   315  
   316  // fetchOAuth2Token fetches an OAuth2 access token.
   317  // Reference: https://docs.docker.com/registry/spec/auth/oauth/
   318  func (c *Client) fetchOAuth2Token(ctx context.Context, realm, service string, scopes []string, cred Credential) (string, error) {
   319  	form := url.Values{}
   320  	if cred.RefreshToken != "" {
   321  		form.Set("grant_type", "refresh_token")
   322  		form.Set("refresh_token", cred.RefreshToken)
   323  	} else if cred.Username != "" && cred.Password != "" {
   324  		form.Set("grant_type", "password")
   325  		form.Set("username", cred.Username)
   326  		form.Set("password", cred.Password)
   327  	} else {
   328  		return "", errors.New("missing username or password for bearer auth")
   329  	}
   330  	form.Set("service", service)
   331  	clientID := c.ClientID
   332  	if clientID == "" {
   333  		clientID = defaultClientID
   334  	}
   335  	form.Set("client_id", clientID)
   336  	if len(scopes) != 0 {
   337  		form.Set("scope", strings.Join(scopes, " "))
   338  	}
   339  	body := strings.NewReader(form.Encode())
   340  
   341  	req, err := http.NewRequestWithContext(ctx, http.MethodPost, realm, body)
   342  	if err != nil {
   343  		return "", err
   344  	}
   345  	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
   346  
   347  	resp, err := c.send(req)
   348  	if err != nil {
   349  		return "", err
   350  	}
   351  	defer resp.Body.Close()
   352  	if resp.StatusCode != http.StatusOK {
   353  		return "", errutil.ParseErrorResponse(resp)
   354  	}
   355  
   356  	var result struct {
   357  		AccessToken string `json:"access_token"`
   358  	}
   359  	lr := io.LimitReader(resp.Body, maxResponseBytes)
   360  	if err := json.NewDecoder(lr).Decode(&result); err != nil {
   361  		return "", fmt.Errorf("%s %q: failed to decode response: %w", resp.Request.Method, resp.Request.URL, err)
   362  	}
   363  	if result.AccessToken != "" {
   364  		return result.AccessToken, nil
   365  	}
   366  	return "", fmt.Errorf("%s %q: empty token returned", resp.Request.Method, resp.Request.URL)
   367  }
   368  

View as plain text