...

Source file src/cuelabs.dev/go/oci/ociregistry/ociauth/auth.go

Documentation: cuelabs.dev/go/oci/ociregistry/ociauth

     1  package ociauth
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net/http"
    10  	"net/url"
    11  	"strings"
    12  	"sync"
    13  	"time"
    14  
    15  	"cuelabs.dev/go/oci/ociregistry/internal/exp/slices"
    16  )
    17  
    18  // TODO decide on a good value for this.
    19  const oauthClientID = "cuelabs-ociauth"
    20  
    21  var ErrNoAuth = fmt.Errorf("no authorization token available to add to request")
    22  
    23  // stdTransport implements [http.RoundTripper] by acquiring authorization tokens
    24  // using the flows implemented
    25  // by the usual docker clients. Note that this is _not_ documented as
    26  // part of any official OCI spec.
    27  //
    28  // See https://distribution.github.io/distribution/spec/auth/token/ for an overview.
    29  type stdTransport struct {
    30  	config     Config
    31  	transport  http.RoundTripper
    32  	mu         sync.Mutex
    33  	registries map[string]*registry
    34  }
    35  
    36  type StdTransportParams struct {
    37  	// Config represents the underlying configuration file information.
    38  	// It is consulted for authorization information on the hosts
    39  	// to which the HTTP requests are made.
    40  	Config Config
    41  
    42  	// HTTPClient is used to make the underlying HTTP requests.
    43  	// If it's nil, [http.DefaultTransport] will be used.
    44  	Transport http.RoundTripper
    45  }
    46  
    47  // NewStdTransport returns an [http.RoundTripper] implementation that
    48  // acquires authorization tokens using the flows implemented by the
    49  // usual docker clients. Note that this is _not_ documented as part of
    50  // any official OCI spec.
    51  //
    52  // See https://distribution.github.io/distribution/spec/auth/token/ for an overview.
    53  //
    54  // The RoundTrip method acquires authorization before invoking the
    55  // request. request. It may invoke the request more than once, and can
    56  // use [http.Request.GetBody] to reset the request body if it gets
    57  // consumed.
    58  //
    59  // It ensures that the authorization token used will have at least the
    60  // capability to execute operations in the required scope associated
    61  // with the request context (see [ContextWithRequestInfo]). Any other
    62  // auth scope inside the context (see [ContextWithScope]) may also be
    63  // taken into account when acquiring new tokens.
    64  func NewStdTransport(p StdTransportParams) http.RoundTripper {
    65  	if p.Config == nil {
    66  		p.Config = emptyConfig{}
    67  	}
    68  	if p.Transport == nil {
    69  		p.Transport = http.DefaultTransport
    70  	}
    71  	return &stdTransport{
    72  		config:     p.Config,
    73  		transport:  p.Transport,
    74  		registries: make(map[string]*registry),
    75  	}
    76  }
    77  
    78  // registry holds currently known auth information for a registry.
    79  type registry struct {
    80  	host      string
    81  	transport http.RoundTripper
    82  	config    Config
    83  	initOnce  sync.Once
    84  	initErr   error
    85  
    86  	// mu guards the fields that follow it.
    87  	mu sync.Mutex
    88  
    89  	// wwwAuthenticate holds the Www-Authenticate header from
    90  	// the most recent 401 response. If there was a 401 response
    91  	// that didn't hold such a header, this will still be non-nil
    92  	// but hold a zero authHeader.
    93  	wwwAuthenticate *authHeader
    94  
    95  	accessTokens []*scopedToken
    96  	refreshToken string
    97  	basic        *userPass
    98  }
    99  
   100  type scopedToken struct {
   101  	// scope holds the scope that the token is good for.
   102  	scope Scope
   103  	// token holds the actual access token.
   104  	token string
   105  	// expires holds when the token expires.
   106  	expires time.Time
   107  }
   108  
   109  type userPass struct {
   110  	username string
   111  	password string
   112  }
   113  
   114  var forever = time.Date(99999, time.January, 1, 0, 0, 0, 0, time.UTC)
   115  
   116  // RoundTrip implements [http.RoundTripper.RoundTrip].
   117  func (a *stdTransport) RoundTrip(req *http.Request) (*http.Response, error) {
   118  	// From the [http.RoundTripper] docs:
   119  	//	RoundTrip should not modify the request, except for
   120  	//	consuming and closing the Request's Body.
   121  	req = req.Clone(req.Context())
   122  
   123  	// From the [http.RoundTripper] docs:
   124  	//	RoundTrip must always close the body, including on errors, [...]
   125  	needBodyClose := true
   126  	defer func() {
   127  		if needBodyClose && req.Body != nil {
   128  			req.Body.Close()
   129  		}
   130  	}()
   131  
   132  	a.mu.Lock()
   133  	r := a.registries[req.URL.Host]
   134  	if r == nil {
   135  		r = &registry{
   136  			host:      req.URL.Host,
   137  			config:    a.config,
   138  			transport: a.transport,
   139  		}
   140  		a.registries[r.host] = r
   141  	}
   142  	a.mu.Unlock()
   143  	if err := r.init(); err != nil {
   144  		return nil, err
   145  	}
   146  
   147  	ctx := req.Context()
   148  	requiredScope := RequestInfoFromContext(ctx).RequiredScope
   149  	wantScope := ScopeFromContext(ctx)
   150  
   151  	if err := r.setAuthorization(ctx, req, requiredScope, wantScope); err != nil {
   152  		return nil, err
   153  	}
   154  	resp, err := r.transport.RoundTrip(req)
   155  
   156  	// The underlying transport should now have closed the request body
   157  	// so we don't have to.
   158  	needBodyClose = false
   159  	if err != nil {
   160  		return nil, err
   161  	}
   162  	if resp.StatusCode != http.StatusUnauthorized {
   163  		return resp, nil
   164  	}
   165  	challenge := challengeFromResponse(resp)
   166  	if challenge == nil {
   167  		return resp, nil
   168  	}
   169  	authAdded, err := r.setAuthorizationFromChallenge(ctx, req, challenge, requiredScope, wantScope)
   170  	if err != nil {
   171  		resp.Body.Close()
   172  		return nil, err
   173  	}
   174  	if !authAdded {
   175  		// Couldn't acquire any more authorization than we had initially.
   176  		return resp, nil
   177  	}
   178  	resp.Body.Close()
   179  	// rewind request body if needed and possible.
   180  	if req.GetBody != nil {
   181  		req.Body, err = req.GetBody()
   182  		if err != nil {
   183  			return nil, err
   184  		}
   185  	}
   186  	return r.transport.RoundTrip(req)
   187  }
   188  
   189  // setAuthorization sets up authorization on the given request using any
   190  // auth information currently available.
   191  func (r *registry) setAuthorization(ctx context.Context, req *http.Request, requiredScope, wantScope Scope) error {
   192  	r.mu.Lock()
   193  	defer r.mu.Unlock()
   194  	// Remove tokens that have expired or will expire soon so that
   195  	// the caller doesn't start using a token only for it to expire while it's
   196  	// making the request.
   197  	r.deleteExpiredTokens(time.Now().UTC().Add(time.Second))
   198  
   199  	if accessToken := r.accessTokenForScope(requiredScope); accessToken != nil {
   200  		// We have a potentially valid access token. Use it.
   201  		req.Header.Set("Authorization", "Bearer "+accessToken.token)
   202  		return nil
   203  	}
   204  	if r.wwwAuthenticate == nil {
   205  		// We haven't seen a 401 response yet. Avoid putting any
   206  		// basic authorization in the request, because that can mean that
   207  		// the server sends a 401 response without a Www-Authenticate
   208  		// header.
   209  		return nil
   210  	}
   211  	if r.refreshToken != "" && r.wwwAuthenticate.scheme == "bearer" {
   212  		// We've got a refresh token that we can use to try to
   213  		// acquire an access token and we've seen a Www-Authenticate response
   214  		// that tells us how we can use it.
   215  
   216  		// TODO we're holding the lock (r.mu) here, which is precluding
   217  		// acquiring several tokens concurrently. We should relax the lock
   218  		// to allow that.
   219  
   220  		accessToken, err := r.acquireAccessToken(ctx, requiredScope, wantScope)
   221  		if err != nil {
   222  			return err
   223  		}
   224  		req.Header.Set("Authorization", "Bearer "+accessToken)
   225  		return nil
   226  	}
   227  	if r.wwwAuthenticate.scheme != "bearer" && r.basic != nil {
   228  		req.SetBasicAuth(r.basic.username, r.basic.password)
   229  		return nil
   230  	}
   231  	return nil
   232  }
   233  
   234  func (r *registry) setAuthorizationFromChallenge(ctx context.Context, req *http.Request, challenge *authHeader, requiredScope, wantScope Scope) (bool, error) {
   235  	r.mu.Lock()
   236  	defer r.mu.Unlock()
   237  	r.wwwAuthenticate = challenge
   238  
   239  	switch {
   240  	case r.wwwAuthenticate.scheme == "bearer":
   241  		scope := ParseScope(r.wwwAuthenticate.params["scope"])
   242  		accessToken, err := r.acquireAccessToken(ctx, scope, wantScope.Union(requiredScope))
   243  		if err != nil {
   244  			return false, err
   245  		}
   246  		req.Header.Set("Authorization", "Bearer "+accessToken)
   247  		return true, nil
   248  	case r.basic != nil:
   249  		req.SetBasicAuth(r.basic.username, r.basic.password)
   250  		return true, nil
   251  	}
   252  	return false, nil
   253  }
   254  
   255  // init initializes the registry instance by acquiring auth information from
   256  // the Config, if available. As this might be slow (invoking EntryForRegistry
   257  // can end up invoking slow external commands), we ensure that it's only
   258  // done once.
   259  // TODO it's possible that this could take a very long time, during which
   260  // the outer context is cancelled, but we'll ignore that. We probably shouldn't.
   261  func (r *registry) init() error {
   262  	inner := func() error {
   263  		info, err := r.config.EntryForRegistry(r.host)
   264  		if err != nil {
   265  			return fmt.Errorf("cannot acquire auth info for registry %q: %v", r.host, err)
   266  		}
   267  		r.refreshToken = info.RefreshToken
   268  		if info.AccessToken != "" {
   269  			r.accessTokens = append(r.accessTokens, &scopedToken{
   270  				scope:   UnlimitedScope(),
   271  				token:   info.AccessToken,
   272  				expires: forever,
   273  			})
   274  		}
   275  		if info.Username != "" && info.Password != "" {
   276  			r.basic = &userPass{
   277  				username: info.Username,
   278  				password: info.Password,
   279  			}
   280  		}
   281  		return nil
   282  	}
   283  	r.initOnce.Do(func() {
   284  		r.initErr = inner()
   285  	})
   286  	return r.initErr
   287  }
   288  
   289  // acquireAccessToken tries to acquire an access token for authorizing a request.
   290  // The requiredScopeStr parameter indicates the scope that's definitely
   291  // required. This is a string because apparently some servers are picky
   292  // about getting exactly the same scope in the auth request that was
   293  // returned in the challenge. The wantScope parameter indicates
   294  // what scope might be required in the future.
   295  //
   296  // This method assumes that there has been a previous 401 response with
   297  // a Www-Authenticate: Bearer... header.
   298  func (r *registry) acquireAccessToken(ctx context.Context, requiredScope, wantScope Scope) (string, error) {
   299  	scope := requiredScope.Union(wantScope)
   300  	tok, err := r.acquireToken(ctx, scope)
   301  	if err != nil {
   302  		var rerr *responseError
   303  		if !errors.As(err, &rerr) || rerr.statusCode != http.StatusUnauthorized {
   304  			return "", err
   305  		}
   306  		// The documentation says this:
   307  		//
   308  		//	If the client only has a subset of the requested
   309  		// 	access it _must not be considered an error_ as it is
   310  		//	not the responsibility of the token server to
   311  		//	indicate authorization errors as part of this
   312  		//	workflow.
   313  		//
   314  		// However it's apparently not uncommon for servers to reject
   315  		// such requests anyway, so if we've got an unauthorized error
   316  		// and wantScope goes beyond requiredScope, it may be because
   317  		// the server is rejecting the request.
   318  		scope = requiredScope
   319  		tok, err = r.acquireToken(ctx, scope)
   320  		if err != nil {
   321  			return "", err
   322  		}
   323  		// TODO mark the registry as picky about tokens so we don't
   324  		// attempt twice every time?
   325  	}
   326  	if tok.RefreshToken != "" {
   327  		r.refreshToken = tok.RefreshToken
   328  	}
   329  	accessToken := tok.Token
   330  	if accessToken == "" {
   331  		accessToken = tok.AccessToken
   332  	}
   333  	if accessToken == "" {
   334  		return "", fmt.Errorf("no access token found in auth server response")
   335  	}
   336  	var expires time.Time
   337  	now := time.Now().UTC()
   338  	if tok.ExpiresIn == 0 {
   339  		expires = now.Add(60 * time.Second) // TODO link to where this is mentioned
   340  	} else {
   341  		expires = now.Add(time.Duration(tok.ExpiresIn) * time.Second)
   342  	}
   343  	r.accessTokens = append(r.accessTokens, &scopedToken{
   344  		scope:   scope,
   345  		token:   accessToken,
   346  		expires: expires,
   347  	})
   348  	// TODO persist the access token to save round trips when doing
   349  	// the authorization flow in a newly run executable.
   350  	return accessToken, nil
   351  }
   352  
   353  func (r *registry) acquireToken(ctx context.Context, scope Scope) (*wireToken, error) {
   354  	realm := r.wwwAuthenticate.params["realm"]
   355  	if realm == "" {
   356  		return nil, fmt.Errorf("malformed Www-Authenticate header (missing realm)")
   357  	}
   358  	if r.refreshToken != "" {
   359  		v := url.Values{}
   360  		v.Set("scope", scope.String())
   361  		if service := r.wwwAuthenticate.params["service"]; service != "" {
   362  			v.Set("service", service)
   363  		}
   364  		v.Set("client_id", oauthClientID)
   365  		v.Set("grant_type", "refresh_token")
   366  		v.Set("refresh_token", r.refreshToken)
   367  		req, err := http.NewRequestWithContext(ctx, "POST", realm, strings.NewReader(v.Encode()))
   368  		if err != nil {
   369  			return nil, fmt.Errorf("cannot form HTTP request to %q: %v", realm, err)
   370  		}
   371  		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
   372  		tok, err := r.doTokenRequest(req)
   373  		if err == nil {
   374  			return tok, nil
   375  		}
   376  		var rerr *responseError
   377  		if !errors.As(err, &rerr) || rerr.statusCode != http.StatusNotFound {
   378  			return tok, err
   379  		}
   380  		// The request to the endpoint returned 404 from the POST request,
   381  		// Note: Not all token servers implement oauth2, so fall
   382  		// back to using a GET with basic auth.
   383  		// See the Token documentation for the HTTP GET method supported by all token servers.
   384  		// TODO where in that documentation is this documented?
   385  	}
   386  	u, err := url.Parse(realm)
   387  	if err != nil {
   388  		return nil, fmt.Errorf("malformed Www-Authenticate header (malformed realm %q): %v", realm, err)
   389  	}
   390  	v := u.Query()
   391  	// TODO where is it documented that we should send multiple scope
   392  	// attributes rather than a single space-separated attribute as
   393  	// the POST method does?
   394  	v["scope"] = strings.Split(scope.String(), " ")
   395  	if service := r.wwwAuthenticate.params["service"]; service != "" {
   396  		// TODO the containerregistry code sets this even if it's empty.
   397  		// Is that better?
   398  		v.Set("service", service)
   399  	}
   400  	u.RawQuery = v.Encode()
   401  	req, err := http.NewRequest("GET", u.String(), nil)
   402  	if err != nil {
   403  		return nil, err
   404  	}
   405  	// TODO if there's an unlimited-scope access token, the original code
   406  	// will use it as Bearer authorization at this point. If
   407  	// that's valid, why are we even acquiring another token?
   408  	if r.basic != nil {
   409  		req.SetBasicAuth(r.basic.username, r.basic.password)
   410  	}
   411  	return r.doTokenRequest(req)
   412  }
   413  
   414  // wireToken describes the JSON encoding used in the response to a token
   415  // acquisition method. The comments are taken from the [token docs]
   416  // and made available here for ease of reference.
   417  //
   418  // [token docs]: https://distribution.github.io/distribution/spec/auth/token/#token-response-fields
   419  type wireToken struct {
   420  	// Token holds an opaque Bearer token that clients should supply
   421  	// to subsequent requests in the Authorization header.
   422  	// AccessToken is provided for compatibility with OAuth 2.0: it's equivalent to Token.
   423  	// At least one of these fields must be specified, but both may also appear (for compatibility with older clients).
   424  	// When both are specified, they should be equivalent; if they differ the client's choice is undefined.
   425  	Token       string `json:"token"`
   426  	AccessToken string `json:"access_token,omitempty"`
   427  
   428  	// Refresh token optionally holds a token which can be used to
   429  	// get additional access tokens for the same subject with different scopes.
   430  	// This token should be kept secure by the client and only sent
   431  	// to the authorization server which issues bearer tokens. This
   432  	// field will only be set when `offline_token=true` is provided
   433  	// in the request.
   434  	RefreshToken string `json:"refresh_token"`
   435  
   436  	// ExpiresIn holds the duration in seconds since the token was
   437  	// issued that it will remain valid. When omitted, this defaults
   438  	// to 60 seconds. For compatibility with older clients, a token
   439  	// should never be returned with less than 60 seconds to live.
   440  	ExpiresIn int `json:"expires_in"`
   441  }
   442  
   443  func (r *registry) doTokenRequest(req *http.Request) (*wireToken, error) {
   444  	client := &http.Client{
   445  		Transport: r.transport,
   446  	}
   447  	resp, err := client.Do(req)
   448  	if err != nil {
   449  		return nil, err
   450  	}
   451  	defer resp.Body.Close()
   452  	if resp.StatusCode != http.StatusOK {
   453  		return nil, errorFromResponse(resp)
   454  	}
   455  	data, err := io.ReadAll(resp.Body)
   456  	if err != nil {
   457  		return nil, fmt.Errorf("cannot read response body: %v", err)
   458  	}
   459  	var tok wireToken
   460  	if err := json.Unmarshal(data, &tok); err != nil {
   461  		return nil, fmt.Errorf("malformed JSON token in response: %v", err)
   462  	}
   463  	return &tok, nil
   464  }
   465  
   466  type responseError struct {
   467  	statusCode int
   468  	msg        string
   469  }
   470  
   471  func errorFromResponse(resp *http.Response) error {
   472  	// TODO include body of response in error message.
   473  	return &responseError{
   474  		statusCode: resp.StatusCode,
   475  	}
   476  }
   477  
   478  func (e *responseError) Error() string {
   479  	return fmt.Sprintf("unexpected HTTP response %d", e.statusCode)
   480  }
   481  
   482  // deleteExpiredTokens removes all tokens from r that expire after the given
   483  // time.
   484  // TODO ask the store to remove expired tokens?
   485  func (r *registry) deleteExpiredTokens(now time.Time) {
   486  	r.accessTokens = slices.DeleteFunc(r.accessTokens, func(tok *scopedToken) bool {
   487  		return now.After(tok.expires)
   488  	})
   489  }
   490  
   491  func (r *registry) accessTokenForScope(scope Scope) *scopedToken {
   492  	for _, tok := range r.accessTokens {
   493  		if tok.scope.Contains(scope) {
   494  			// TODO prefer tokens with less scope?
   495  			return tok
   496  		}
   497  	}
   498  	return nil
   499  }
   500  
   501  type emptyConfig struct{}
   502  
   503  func (emptyConfig) EntryForRegistry(host string) (ConfigEntry, error) {
   504  	return ConfigEntry{}, nil
   505  }
   506  

View as plain text