...

Source file src/cuelabs.dev/go/oci/ociregistry/ociclient/client.go

Documentation: cuelabs.dev/go/oci/ociregistry/ociclient

     1  // Copyright 2023 CUE Labs AG
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package ociclient provides an implementation of ociregistry.Interface that
    16  // uses HTTP to talk to the remote registry.
    17  package ociclient
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"fmt"
    23  	"hash"
    24  	"io"
    25  	"log"
    26  	"net/http"
    27  	"net/url"
    28  	"strconv"
    29  	"strings"
    30  	"sync/atomic"
    31  
    32  	"github.com/opencontainers/go-digest"
    33  	ocispec "github.com/opencontainers/image-spec/specs-go/v1"
    34  
    35  	"cuelabs.dev/go/oci/ociregistry"
    36  	"cuelabs.dev/go/oci/ociregistry/internal/ocirequest"
    37  	"cuelabs.dev/go/oci/ociregistry/ociauth"
    38  )
    39  
// debug enables logging. When it is true, the client's do method logs
// every outgoing request (method, URL, headers) and every response
// (status, headers and the entire body, which is read into memory first).
// TODO this should be configurable in the API.
const debug = false
    43  
// Options holds the optional parameters accepted by [New].
// The zero value selects the defaults described on each field.
type Options struct {
	// DebugID is used to prefix any log messages printed by the client.
	DebugID string

	// Transport is used to make HTTP requests. The context passed
	// to its RoundTrip method will have an appropriate
	// [ociauth.RequestInfo] value added, suitable for consumption
	// by the transport created by [ociauth.NewStdTransport]. If
	// Transport is nil, [http.DefaultTransport] will be used.
	Transport http.RoundTripper

	// Insecure specifies whether an http scheme will be used to
	// address the host instead of https.
	Insecure bool

	// ListPageSize configures the maximum number of results
	// requested when making list requests. If it's <= zero, it
	// defaults to DefaultListPageSize.
	ListPageSize int
}
    64  
    65  // See https://github.com/google/go-containerregistry/issues/1091
    66  // for an early report of the issue alluded to below.
    67  
// DefaultListPageSize holds the default number of results
// to request when using the list endpoints.
// It's not more than 1000 because AWS ECR complains
// if it's more than that.
const DefaultListPageSize = 1000
    73  
// debugID is a counter that New increments atomically to generate a
// distinct default Options.DebugID ("id1", "id2", ...) for each client
// created without an explicit one.
var debugID int32
    75  
    76  // New returns a registry implementation that uses the OCI
    77  // HTTP API. A nil opts parameter is equivalent to a pointer
    78  // to zero Options.
    79  //
    80  // The host specifies the host name to talk to; it may
    81  // optionally be a host:port pair.
    82  func New(host string, opts0 *Options) (ociregistry.Interface, error) {
    83  	var opts Options
    84  	if opts0 != nil {
    85  		opts = *opts0
    86  	}
    87  	if opts.DebugID == "" {
    88  		opts.DebugID = fmt.Sprintf("id%d", atomic.AddInt32(&debugID, 1))
    89  	}
    90  	if opts.Transport == nil {
    91  		opts.Transport = http.DefaultTransport
    92  	}
    93  	// Check that it's a valid host by forming a URL from it and checking that it matches.
    94  	u, err := url.Parse("https://" + host + "/path")
    95  	if err != nil {
    96  		return nil, fmt.Errorf("invalid host %q", host)
    97  	}
    98  	if u.Host != host {
    99  		return nil, fmt.Errorf("invalid host %q (does not correctly form a host part of a URL)", host)
   100  	}
   101  	if opts.Insecure {
   102  		u.Scheme = "http"
   103  	}
   104  	if opts.ListPageSize == 0 {
   105  		opts.ListPageSize = DefaultListPageSize
   106  	}
   107  	return &client{
   108  		httpHost:   host,
   109  		httpScheme: u.Scheme,
   110  		httpClient: &http.Client{
   111  			Transport: opts.Transport,
   112  		},
   113  		debugID:      opts.DebugID,
   114  		listPageSize: opts.ListPageSize,
   115  	}, nil
   116  }
   117  
// client is the concrete type returned by New.
//
// httpScheme and httpHost are filled into any request URL that lacks
// them before it is sent with httpClient; httpScheme is "https" unless
// Options.Insecure was set, in which case it is "http". debugID
// prefixes log messages written by logf. listPageSize is the resolved
// Options.ListPageSize.
//
// NOTE(review): the embedded *ociregistry.Funcs is never set by New in
// this file; presumably a nil Funcs supplies default implementations of
// the ociregistry.Interface methods that client does not define itself
// — confirm against the ociregistry package.
type client struct {
	*ociregistry.Funcs
	httpScheme   string
	httpHost     string
	httpClient   *http.Client
	debugID      string
	listPageSize int
}
   126  
   127  // descriptorFromResponse tries to form a descriptor from an HTTP response,
   128  // filling in the Digest field using knownDigest if it's not present.
   129  //
   130  // Note: this implies that the Digest field will be empty if there is no
   131  // digest in the response and knownDigest is empty.
   132  func descriptorFromResponse(resp *http.Response, knownDigest digest.Digest, requireSize bool) (ociregistry.Descriptor, error) {
   133  	contentType := resp.Header.Get("Content-Type")
   134  	if contentType == "" {
   135  		contentType = "application/octet-stream"
   136  	}
   137  	size := int64(0)
   138  	if requireSize {
   139  		if resp.StatusCode == http.StatusPartialContent {
   140  			contentRange := resp.Header.Get("Content-Range")
   141  			if contentRange == "" {
   142  				return ociregistry.Descriptor{}, fmt.Errorf("no Content-Range in partial content response")
   143  			}
   144  			i := strings.LastIndex(contentRange, "/")
   145  			if i == -1 {
   146  				return ociregistry.Descriptor{}, fmt.Errorf("malformed Content-Range %q", contentRange)
   147  			}
   148  			contentSize, err := strconv.ParseInt(contentRange[i+1:], 10, 64)
   149  			if err != nil {
   150  				return ociregistry.Descriptor{}, fmt.Errorf("malformed Content-Range %q", contentRange)
   151  			}
   152  			size = contentSize
   153  		} else {
   154  			if resp.ContentLength < 0 {
   155  				return ociregistry.Descriptor{}, fmt.Errorf("unknown content length")
   156  			}
   157  			size = resp.ContentLength
   158  		}
   159  	}
   160  	digest := digest.Digest(resp.Header.Get("Docker-Content-Digest"))
   161  	if digest != "" {
   162  		if !ociregistry.IsValidDigest(string(digest)) {
   163  			return ociregistry.Descriptor{}, fmt.Errorf("bad digest %q found in response", digest)
   164  		}
   165  	} else {
   166  		digest = knownDigest
   167  	}
   168  	return ociregistry.Descriptor{
   169  		Digest:    digest,
   170  		MediaType: contentType,
   171  		Size:      size,
   172  	}, nil
   173  }
   174  
   175  func newBlobReader(r io.ReadCloser, desc ociregistry.Descriptor) *blobReader {
   176  	return &blobReader{
   177  		r:        r,
   178  		digester: desc.Digest.Algorithm().Hash(),
   179  		desc:     desc,
   180  		verify:   true,
   181  	}
   182  }
   183  
   184  func newBlobReaderUnverified(r io.ReadCloser, desc ociregistry.Descriptor) *blobReader {
   185  	br := newBlobReader(r, desc)
   186  	br.verify = false
   187  	return br
   188  }
   189  
// blobReader reads blob content from r while tracking how much has
// been read and hashing it.
//
// n is the number of bytes read so far and digester accumulates a hash
// of those bytes. desc is the descriptor the content is expected to
// match; it is also what Descriptor returns. When verify is true, Read
// checks the byte count and digest against desc on reaching EOF.
type blobReader struct {
	r        io.ReadCloser
	n        int64
	digester hash.Hash
	desc     ociregistry.Descriptor
	verify   bool
}
   197  
// Descriptor returns the descriptor that the reader was created with.
func (r *blobReader) Descriptor() ociregistry.Descriptor {
	return r.desc
}
   201  
   202  func (r *blobReader) Read(buf []byte) (int, error) {
   203  	n, err := r.r.Read(buf)
   204  	r.n += int64(n)
   205  	r.digester.Write(buf[:n])
   206  	if err == nil {
   207  		if r.n > r.desc.Size {
   208  			// Fail early when the blob is too big; we can do that even
   209  			// when we're not verifying for other use cases.
   210  			return n, fmt.Errorf("blob size exceeds content length %d: %w", r.desc.Size, ociregistry.ErrSizeInvalid)
   211  		}
   212  		return n, nil
   213  	}
   214  	if err != io.EOF {
   215  		return n, err
   216  	}
   217  	if !r.verify {
   218  		return n, io.EOF
   219  	}
   220  	if r.n != r.desc.Size {
   221  		return n, fmt.Errorf("blob size mismatch (%d/%d): %w", r.n, r.desc.Size, ociregistry.ErrSizeInvalid)
   222  	}
   223  	gotDigest := digest.NewDigest(r.desc.Digest.Algorithm(), r.digester)
   224  	if gotDigest != r.desc.Digest {
   225  		return n, fmt.Errorf("digest mismatch when reading blob")
   226  	}
   227  	return n, io.EOF
   228  }
   229  
// Close closes the underlying reader and returns its error, if any.
// It performs no size or digest verification.
func (r *blobReader) Close() error {
	return r.r.Close()
}
   233  
// knownManifestMediaTypes holds the media types that doRequest sends
// as the Accept header values of manifest GET and HEAD requests.
//
// TODO make this list configurable.
var knownManifestMediaTypes = []string{
	ocispec.MediaTypeImageManifest,
	ocispec.MediaTypeImageIndex,
	"application/vnd.oci.artifact.manifest.v1+json", // deprecated.
	"application/vnd.docker.distribution.manifest.v1+json",
	"application/vnd.docker.distribution.manifest.v2+json",
	"application/vnd.docker.distribution.manifest.list.v2+json",
	// Technically this wildcard should be sufficient, but it isn't
	// recognized by some registries.
	"*/*",
}
   246  
   247  // doRequest performs the given OCI request, sending it with the given body (which may be nil).
   248  func (c *client) doRequest(ctx context.Context, rreq *ocirequest.Request, okStatuses ...int) (*http.Response, error) {
   249  	req, err := newRequest(ctx, rreq, nil)
   250  	if err != nil {
   251  		return nil, err
   252  	}
   253  	if rreq.Kind == ocirequest.ReqManifestGet || rreq.Kind == ocirequest.ReqManifestHead {
   254  		// When getting manifests, some servers won't return
   255  		// the content unless there's an Accept header, so
   256  		// add all the manifest kinds that we know about.
   257  		req.Header["Accept"] = knownManifestMediaTypes
   258  	}
   259  	resp, err := c.do(req, okStatuses...)
   260  	if err != nil {
   261  		return nil, err
   262  	}
   263  	if resp.StatusCode/100 == 2 {
   264  		return resp, nil
   265  	}
   266  	defer resp.Body.Close()
   267  	return nil, makeError(resp)
   268  }
   269  
   270  func (c *client) do(req *http.Request, okStatuses ...int) (*http.Response, error) {
   271  	if req.URL.Scheme == "" {
   272  		req.URL.Scheme = c.httpScheme
   273  	}
   274  	if req.URL.Host == "" {
   275  		req.URL.Host = c.httpHost
   276  	}
   277  	if req.Body != nil {
   278  		// Ensure that the body isn't consumed until the
   279  		// server has responded that it will receive it.
   280  		// This means that we can retry requests even when we've
   281  		// got a consume-once-only io.Reader, such as
   282  		// when pushing blobs.
   283  		req.Header.Set("Expect", "100-continue")
   284  	}
   285  	var buf bytes.Buffer
   286  	if debug {
   287  		fmt.Fprintf(&buf, "client.Do: %s %s {{\n", req.Method, req.URL)
   288  		fmt.Fprintf(&buf, "\tBODY: %#v\n", req.Body)
   289  		for k, v := range req.Header {
   290  			fmt.Fprintf(&buf, "\t%s: %q\n", k, v)
   291  		}
   292  		c.logf("%s", buf.Bytes())
   293  	}
   294  	resp, err := c.httpClient.Do(req)
   295  	if err != nil {
   296  		return nil, fmt.Errorf("cannot do HTTP request: %w", err)
   297  	}
   298  	if debug {
   299  		buf.Reset()
   300  		fmt.Fprintf(&buf, "} -> %s {\n", resp.Status)
   301  		for k, v := range resp.Header {
   302  			fmt.Fprintf(&buf, "\t%s: %q\n", k, v)
   303  		}
   304  		data, _ := io.ReadAll(resp.Body)
   305  		if len(data) > 0 {
   306  			fmt.Fprintf(&buf, "\tBODY: %q\n", data)
   307  		}
   308  		fmt.Fprintf(&buf, "}}\n")
   309  		resp.Body.Close()
   310  		resp.Body = io.NopCloser(bytes.NewReader(data))
   311  		c.logf("%s", buf.Bytes())
   312  	}
   313  	if len(okStatuses) == 0 && resp.StatusCode == http.StatusOK {
   314  		return resp, nil
   315  	}
   316  	for _, status := range okStatuses {
   317  		if resp.StatusCode == status {
   318  			return resp, nil
   319  		}
   320  	}
   321  	defer resp.Body.Close()
   322  	if !isOKStatus(resp.StatusCode) {
   323  		return nil, makeError(resp)
   324  	}
   325  	return nil, unexpectedStatusError(resp.StatusCode)
   326  }
   327  
   328  func (c *client) logf(f string, a ...any) {
   329  	log.Printf("ociclient %s: %s", c.debugID, fmt.Sprintf(f, a...))
   330  }
   331  
   332  func locationFromResponse(resp *http.Response) (*url.URL, error) {
   333  	location := resp.Header.Get("Location")
   334  	if location == "" {
   335  		return nil, fmt.Errorf("no Location found in response")
   336  	}
   337  	u, err := url.Parse(location)
   338  	if err != nil {
   339  		return nil, fmt.Errorf("invalid Location URL found in response")
   340  	}
   341  	return resp.Request.URL.ResolveReference(u), nil
   342  }
   343  
   344  func isOKStatus(code int) bool {
   345  	return code/100 == 2
   346  }
   347  
   348  func closeOnError(err *error, r io.Closer) {
   349  	if *err != nil {
   350  		r.Close()
   351  	}
   352  }
   353  
   354  func unexpectedStatusError(code int) error {
   355  	return fmt.Errorf("unexpected HTTP response code %d", code)
   356  }
   357  
   358  func scopeForRequest(r *ocirequest.Request) ociauth.Scope {
   359  	switch r.Kind {
   360  	case ocirequest.ReqPing:
   361  		return ociauth.Scope{}
   362  	case ocirequest.ReqBlobGet,
   363  		ocirequest.ReqBlobHead,
   364  		ocirequest.ReqManifestGet,
   365  		ocirequest.ReqManifestHead,
   366  		ocirequest.ReqTagsList,
   367  		ocirequest.ReqReferrersList:
   368  		return ociauth.NewScope(ociauth.ResourceScope{
   369  			ResourceType: ociauth.TypeRepository,
   370  			Resource:     r.Repo,
   371  			Action:       ociauth.ActionPull,
   372  		})
   373  	case ocirequest.ReqBlobDelete,
   374  		ocirequest.ReqBlobStartUpload,
   375  		ocirequest.ReqBlobUploadBlob,
   376  		ocirequest.ReqBlobUploadInfo,
   377  		ocirequest.ReqBlobUploadChunk,
   378  		ocirequest.ReqBlobCompleteUpload,
   379  		ocirequest.ReqManifestPut,
   380  		ocirequest.ReqManifestDelete:
   381  		return ociauth.NewScope(ociauth.ResourceScope{
   382  			ResourceType: ociauth.TypeRepository,
   383  			Resource:     r.Repo,
   384  			Action:       ociauth.ActionPush,
   385  		})
   386  	case ocirequest.ReqBlobMount:
   387  		return ociauth.NewScope(ociauth.ResourceScope{
   388  			ResourceType: ociauth.TypeRepository,
   389  			Resource:     r.Repo,
   390  			Action:       ociauth.ActionPush,
   391  		}, ociauth.ResourceScope{
   392  			ResourceType: ociauth.TypeRepository,
   393  			Resource:     r.FromRepo,
   394  			Action:       ociauth.ActionPull,
   395  		})
   396  	case ocirequest.ReqCatalogList:
   397  		return ociauth.NewScope(ociauth.CatalogScope)
   398  	default:
   399  		panic(fmt.Errorf("unexpected request kind %v", r.Kind))
   400  	}
   401  }
   402  
   403  func newRequest(ctx context.Context, rreq *ocirequest.Request, body io.Reader) (*http.Request, error) {
   404  	method, u, err := rreq.Construct()
   405  	if err != nil {
   406  		return nil, err
   407  	}
   408  	ctx = ociauth.ContextWithRequestInfo(ctx, ociauth.RequestInfo{
   409  		RequiredScope: scopeForRequest(rreq),
   410  	})
   411  	return http.NewRequestWithContext(ctx, method, u, body)
   412  }
   413  

View as plain text