    23  package estargz
    25  import (
    26  	"archive/tar"
    27  	"bytes"
    28  	"compress/gzip"
    29  	"context"
    30  	"errors"
    31  	"fmt"
    32  	"io"
    33  	"os"
    34  	"path"
    35  	"runtime"
    36  	"strings"
    37  	"sync"
    39  	"github.com/containerd/stargz-snapshotter/estargz/errorutil"
    40  	"github.com/klauspost/compress/zstd"
    41  	digest "github.com/opencontainers/go-digest"
    42  	"golang.org/x/sync/errgroup"
    43  )
// options holds the settings that Build collects from the Option values
// passed to it.
//
//   - chunkSize: copied to the ChunkSize field of every Writer that Build
//     creates.
//   - compressionLevel: gzip level used when no Compression is configured;
//     Build initializes it to gzip.BestCompression.
//   - prioritizedFiles: names handed to sortEntries so that they are placed at
//     the head of the resulting entry list.
//   - missedPrioritizedFiles: when non-nil, receives the prioritized names that
//     were not found instead of making the build fail.
//   - compression: compression implementation; when nil, Build falls back to
//     gzip with compressionLevel.
//   - ctx: optional context; when nil, Build uses context.Background().
//   - minChunkSize: copied to the MinChunkSize field of every Writer; a
//     positive value makes Build process all entries with a single Writer.
type options struct {
	chunkSize              int
	compressionLevel       int
	prioritizedFiles       []string
	missedPrioritizedFiles *[]string
	compression            Compression
	ctx                    context.Context
	minChunkSize           int
}
// Option configures Build by mutating the options value it is given.
// Build stops and returns the error if an Option returns a non-nil one.
type Option func(o *options) error
    57  // WithChunkSize option specifies the chunk size of eStargz blob to build.
    58  func WithChunkSize(chunkSize int) Option {
    59  	return func(o *options) error {
    60  		o.chunkSize = chunkSize
    61  		return nil
    62  	}
    63  }
    65  // WithCompressionLevel option specifies the gzip compression level.
    66  // The default is gzip.BestCompression.
    67  // This option will be ignored if WithCompression option is used.
    68  // See also: https://godoc.org/compress/gzip#pkg-constants
    69  func WithCompressionLevel(level int) Option {
    70  	return func(o *options) error {
    71  		o.compressionLevel = level
    72  		return nil
    73  	}
    74  }
    76  // WithPrioritizedFiles option specifies the list of prioritized files.
    77  // These files must be complete paths that are absolute or relative to "/"
    78  // For example, all of "foo/bar", "/foo/bar", "./foo/bar" and "../foo/bar"
    79  // are treated as "/foo/bar".
    80  func WithPrioritizedFiles(files []string) Option {
    81  	return func(o *options) error {
    82  		o.prioritizedFiles = files
    83  		return nil
    84  	}
    85  }
    87  // WithAllowPrioritizeNotFound makes Build continue the execution even if some
    88  // of prioritized files specified by WithPrioritizedFiles option aren't found
    89  // in the input tar. Instead, this records all missed file names to the passed
    90  // slice.
    91  func WithAllowPrioritizeNotFound(missedFiles *[]string) Option {
    92  	return func(o *options) error {
    93  		if missedFiles == nil {
    94  			return fmt.Errorf("WithAllowPrioritizeNotFound: slice must be passed")
    95  		}
    96  		o.missedPrioritizedFiles = missedFiles
    97  		return nil
    98  	}
    99  }
   101  // WithCompression specifies compression algorithm to be used.
   102  // Default is gzip.
   103  func WithCompression(compression Compression) Option {
   104  	return func(o *options) error {
   105  		o.compression = compression
   106  		return nil
   107  	}
   108  }
   110  // WithContext specifies a context that can be used for clean canceleration.
   111  func WithContext(ctx context.Context) Option {
   112  	return func(o *options) error {
   113  		o.ctx = ctx
   114  		return nil
   115  	}
   116  }
   118  // WithMinChunkSize option specifies the minimal number of bytes of data
   119  // must be written in one gzip stream.
   120  // By increasing this number, one gzip stream can contain multiple files
   121  // and it hopefully leads to smaller result blob.
   122  // NOTE: This adds a TOC property that old reader doesn't understand.
   123  func WithMinChunkSize(minChunkSize int) Option {
   124  	return func(o *options) error {
   125  		o.minChunkSize = minChunkSize
   126  		return nil
   127  	}
   128  }
   130  // Blob is an eStargz blob.
   131  type Blob struct {
   132  	io.ReadCloser
   133  	diffID    digest.Digester
   134  	tocDigest digest.Digest
   135  }
   137  // DiffID returns the digest of uncompressed blob.
   138  // It is only valid to call DiffID after Close.
   139  func (b *Blob) DiffID() digest.Digest {
   140  	return b.diffID.Digest()
   141  }
   143  // TOCDigest returns the digest of uncompressed TOC JSON.
   144  func (b *Blob) TOCDigest() digest.Digest {
   145  	return b.tocDigest
   146  }
   148  // Build builds an eStargz blob which is an extended version of stargz, from a blob (gzip, zstd
   149  // or plain tar) passed through the argument. If there are some prioritized files are listed in
   150  // the option, these files are grouped as "prioritized" and can be used for runtime optimization
   151  // (e.g. prefetch). This function builds a blob in parallel, with dividing that blob into several
   152  // (at least the number of runtime.GOMAXPROCS(0)) sub-blobs.
   153  func Build(tarBlob *io.SectionReader, opt ...Option) (_ *Blob, rErr error) {
   154  	var opts options
   155  	opts.compressionLevel = gzip.BestCompression // BestCompression by default
   156  	for _, o := range opt {
   157  		if err := o(&opts); err != nil {
   158  			return nil, err
   159  		}
   160  	}
   161  	if opts.compression == nil {
   162  		opts.compression = newGzipCompressionWithLevel(opts.compressionLevel)
   163  	}
   164  	layerFiles := newTempFiles()
   165  	ctx := opts.ctx
   166  	if ctx == nil {
   167  		ctx = context.Background()
   168  	}
   169  	done := make(chan struct{})
   170  	defer close(done)
   171  	go func() {
   172  		select {
   173  		case <-done:
   174  			// nop
   175  		case <-ctx.Done():
   176  			layerFiles.CleanupAll()
   177  		}
   178  	}()
   179  	defer func() {
   180  		if rErr != nil {
   181  			if err := layerFiles.CleanupAll(); err != nil {
   182  				rErr = fmt.Errorf("failed to cleanup tmp files: %v: %w", err, rErr)
   183  			}
   184  		}
   185  		if cErr := ctx.Err(); cErr != nil {
   186  			rErr = fmt.Errorf("error from context %q: %w", cErr, rErr)
   187  		}
   188  	}()
   189  	tarBlob, err := decompressBlob(tarBlob, layerFiles)
   190  	if err != nil {
   191  		return nil, err
   192  	}
   193  	entries, err := sortEntries(tarBlob, opts.prioritizedFiles, opts.missedPrioritizedFiles)
   194  	if err != nil {
   195  		return nil, err
   196  	}
   197  	var tarParts [][]*entry
   198  	if opts.minChunkSize > 0 {
   199  		// Each entry needs to know the size of the current gzip stream so they
   200  		// cannot be processed in parallel.
   201  		tarParts = [][]*entry{entries}
   202  	} else {
   203  		tarParts = divideEntries(entries, runtime.GOMAXPROCS(0))
   204  	}
   205  	writers := make([]*Writer, len(tarParts))
   206  	payloads := make([]*os.File, len(tarParts))
   207  	var mu sync.Mutex
   208  	var eg errgroup.Group
   209  	for i, parts := range tarParts {
   210  		i, parts := i, parts
   211  		// builds verifiable stargz sub-blobs
   212  		eg.Go(func() error {
   213  			esgzFile, err := layerFiles.TempFile("", "esgzdata")
   214  			if err != nil {
   215  				return err
   216  			}
   217  			sw := NewWriterWithCompressor(esgzFile, opts.compression)
   218  			sw.ChunkSize = opts.chunkSize
   219  			sw.MinChunkSize = opts.minChunkSize
   220  			if sw.needsOpenGzEntries == nil {
   221  				sw.needsOpenGzEntries = make(map[string]struct{})
   222  			}
   223  			for _, f := range []string{PrefetchLandmark, NoPrefetchLandmark} {
   224  				sw.needsOpenGzEntries[f] = struct{}{}
   225  			}
   226  			if err := sw.AppendTar(readerFromEntries(parts...)); err != nil {
   227  				return err
   228  			}
   229  			mu.Lock()
   230  			writers[i] = sw
   231  			payloads[i] = esgzFile
   232  			mu.Unlock()
   233  			return nil
   234  		})
   235  	}
   236  	if err := eg.Wait(); err != nil {
   237  		rErr = err
   238  		return nil, err
   239  	}
   240  	tocAndFooter, tocDgst, err := closeWithCombine(writers...)
   241  	if err != nil {
   242  		rErr = err
   243  		return nil, err
   244  	}
   245  	var rs []io.Reader
   246  	for _, p := range payloads {
   247  		fs, err := fileSectionReader(p)
   248  		if err != nil {
   249  			return nil, err
   250  		}
   251  		rs = append(rs, fs)
   252  	}
   253  	diffID := digest.Canonical.Digester()
   254  	pr, pw := io.Pipe()
   255  	go func() {
   256  		r, err := opts.compression.Reader(io.TeeReader(io.MultiReader(append(rs, tocAndFooter)...), pw))
   257  		if err != nil {
   258  			pw.CloseWithError(err)
   259  			return
   260  		}
   261  		defer r.Close()
   262  		if _, err := io.Copy(diffID.Hash(), r); err != nil {
   263  			pw.CloseWithError(err)
   264  			return
   265  		}
   266  		pw.Close()
   267  	}()
   268  	return &Blob{
   269  		ReadCloser: readCloser{
   270  			Reader:    pr,
   271  			closeFunc: layerFiles.CleanupAll,
   272  		},
   273  		tocDigest: tocDgst,
   274  		diffID:    diffID,
   275  	}, nil
   276  }
   278  // closeWithCombine takes unclosed Writers and close them. This also returns the
   279  // toc that combined all Writers into.
   280  // Writers doesn't write TOC and footer to the underlying writers so they can be
   281  // combined into a single eStargz and tocAndFooter returned by this function can
   282  // be appended at the tail of that combined blob.
   283  func closeWithCombine(ws ...*Writer) (tocAndFooterR io.Reader, tocDgst digest.Digest, err error) {
   284  	if len(ws) == 0 {
   285  		return nil, "", fmt.Errorf("at least one writer must be passed")
   286  	}
   287  	for _, w := range ws {
   288  		if w.closed {
   289  			return nil, "", fmt.Errorf("writer must be unclosed")
   290  		}
   291  		defer func(w *Writer) { w.closed = true }(w)
   292  		if err := w.closeGz(); err != nil {
   293  			return nil, "", err
   294  		}
   295  		if err := w.bw.Flush(); err != nil {
   296  			return nil, "", err
   297  		}
   298  	}
   299  	var (
   300  		mtoc          = new(JTOC)
   301  		currentOffset int64
   302  	)
   303  	mtoc.Version = ws[0].toc.Version
   304  	for _, w := range ws {
   305  		for _, e := range w.toc.Entries {
   306  			// Recalculate Offset of non-empty files/chunks
   307  			if (e.Type == "reg" && e.Size > 0) || e.Type == "chunk" {
   308  				e.Offset += currentOffset
   309  			}
   310  			mtoc.Entries = append(mtoc.Entries, e)
   311  		}
   312  		if w.toc.Version > mtoc.Version {
   313  			mtoc.Version = w.toc.Version
   314  		}
   315  		currentOffset += w.cw.n
   316  	}
   318  	return tocAndFooter(ws[0].compressor, mtoc, currentOffset)
   319  }
   321  func tocAndFooter(compressor Compressor, toc *JTOC, offset int64) (io.Reader, digest.Digest, error) {
   322  	buf := new(bytes.Buffer)
   323  	tocDigest, err := compressor.WriteTOCAndFooter(buf, offset, toc, nil)
   324  	if err != nil {
   325  		return nil, "", err
   326  	}
   327  	return buf, tocDigest, nil
   328  }
   330  // divideEntries divides passed entries to the parts at least the number specified by the
   331  // argument.
   332  func divideEntries(entries []*entry, minPartsNum int) (set [][]*entry) {
   333  	var estimatedSize int64
   334  	for _, e := range entries {
   335  		estimatedSize += e.header.Size
   336  	}
   337  	unitSize := estimatedSize / int64(minPartsNum)
   338  	var (
   339  		nextEnd = unitSize
   340  		offset  int64
   341  	)
   342  	set = append(set, []*entry{})
   343  	for _, e := range entries {
   344  		set[len(set)-1] = append(set[len(set)-1], e)
   345  		offset += e.header.Size
   346  		if offset > nextEnd {
   347  			set = append(set, []*entry{})
   348  			nextEnd += unitSize
   349  		}
   350  	}
   351  	return
   352  }
// errNotFound is wrapped by moveRec when the requested name is present in
// neither of the two tarFiles; sortEntries detects it with errors.Is.
var errNotFound = errors.New("not found")
   356  // sortEntries reads the specified tar blob and returns a list of tar entries.
   357  // If some of prioritized files are specified, the list starts from these
   358  // files with keeping the order specified by the argument.
   359  func sortEntries(in io.ReaderAt, prioritized []string, missedPrioritized *[]string) ([]*entry, error) {
   361  	// Import tar file.
   362  	intar, err := importTar(in)
   363  	if err != nil {
   364  		return nil, fmt.Errorf("failed to sort: %w", err)
   365  	}
   367  	// Sort the tar file respecting to the prioritized files list.
   368  	sorted := &tarFile{}
   369  	for _, l := range prioritized {
   370  		if err := moveRec(l, intar, sorted); err != nil {
   371  			if errors.Is(err, errNotFound) && missedPrioritized != nil {
   372  				*missedPrioritized = append(*missedPrioritized, l)
   373  				continue // allow not found
   374  			}
   375  			return nil, fmt.Errorf("failed to sort tar entries: %w", err)
   376  		}
   377  	}
   378  	if len(prioritized) == 0 {
   379  		sorted.add(&entry{
   380  			header: &tar.Header{
   381  				Name:     NoPrefetchLandmark,
   382  				Typeflag: tar.TypeReg,
   383  				Size:     int64(len([]byte{landmarkContents})),
   384  			},
   385  			payload: bytes.NewReader([]byte{landmarkContents}),
   386  		})
   387  	} else {
   388  		sorted.add(&entry{
   389  			header: &tar.Header{
   390  				Name:     PrefetchLandmark,
   391  				Typeflag: tar.TypeReg,
   392  				Size:     int64(len([]byte{landmarkContents})),
   393  			},
   394  			payload: bytes.NewReader([]byte{landmarkContents}),
   395  		})
   396  	}
   398  	// Dump all entry and concatinate them.
   399  	return append(sorted.dump(), intar.dump()...), nil
   400  }
   402  // readerFromEntries returns a reader of tar archive that contains entries passed
   403  // through the arguments.
   404  func readerFromEntries(entries ...*entry) io.Reader {
   405  	pr, pw := io.Pipe()
   406  	go func() {
   407  		tw := tar.NewWriter(pw)
   408  		defer tw.Close()
   409  		for _, entry := range entries {
   410  			if err := tw.WriteHeader(entry.header); err != nil {
   411  				pw.CloseWithError(fmt.Errorf("Failed to write tar header: %v", err))
   412  				return
   413  			}
   414  			if _, err := io.Copy(tw, entry.payload); err != nil {
   415  				pw.CloseWithError(fmt.Errorf("Failed to write tar payload: %v", err))
   416  				return
   417  			}
   418  		}
   419  		pw.Close()
   420  	}()
   421  	return pr
   422  }
   424  func importTar(in io.ReaderAt) (*tarFile, error) {
   425  	tf := &tarFile{}
   426  	pw, err := newCountReadSeeker(in)
   427  	if err != nil {
   428  		return nil, fmt.Errorf("failed to make position watcher: %w", err)
   429  	}
   430  	tr := tar.NewReader(pw)
   432  	// Walk through all nodes.
   433  	for {
   434  		// Fetch and parse next header.
   435  		h, err := tr.Next()
   436  		if err != nil {
   437  			if err == io.EOF {
   438  				break
   439  			} else {
   440  				return nil, fmt.Errorf("failed to parse tar file, %w", err)
   441  			}
   442  		}
   443  		switch cleanEntryName(h.Name) {
   444  		case PrefetchLandmark, NoPrefetchLandmark:
   445  			// Ignore existing landmark
   446  			continue
   447  		}
   449  		// Add entry. If it already exists, replace it.
   450  		if _, ok := tf.get(h.Name); ok {
   451  			tf.remove(h.Name)
   452  		}
   453  		tf.add(&entry{
   454  			header:  h,
   455  			payload: io.NewSectionReader(in, pw.currentPos(), h.Size),
   456  		})
   457  	}
   459  	return tf, nil
   460  }
   462  func moveRec(name string, in *tarFile, out *tarFile) error {
   463  	name = cleanEntryName(name)
   464  	if name == "" { // root directory. stop recursion.
   465  		if e, ok := in.get(name); ok {
   466  			// entry of the root directory exists. we should move it as well.
   467  			// this case will occur if tar entries are prefixed with "./", "/", etc.
   468  			out.add(e)
   469  			in.remove(name)
   470  		}
   471  		return nil
   472  	}
   474  	_, okIn := in.get(name)
   475  	_, okOut := out.get(name)
   476  	if !okIn && !okOut {
   477  		return fmt.Errorf("file: %q: %w", name, errNotFound)
   478  	}
   480  	parent, _ := path.Split(strings.TrimSuffix(name, "/"))
   481  	if err := moveRec(parent, in, out); err != nil {
   482  		return err
   483  	}
   484  	if e, ok := in.get(name); ok && e.header.Typeflag == tar.TypeLink {
   485  		if err := moveRec(e.header.Linkname, in, out); err != nil {
   486  			return err
   487  		}
   488  	}
   489  	if e, ok := in.get(name); ok {
   490  		out.add(e)
   491  		in.remove(name)
   492  	}
   493  	return nil
   494  }
// entry is a single tar entry: its header together with a seekable reader
// over its payload.
type entry struct {
	header  *tar.Header
	payload io.ReadSeeker
}
   501  type tarFile struct {
   502  	index  map[string]*entry
   503  	stream []*entry
   504  }
   506  func (f *tarFile) add(e *entry) {
   507  	if f.index == nil {
   508  		f.index = make(map[string]*entry)
   509  	}
   510  	f.index[cleanEntryName(e.header.Name)] = e
   511  	f.stream = append(f.stream, e)
   512  }
   514  func (f *tarFile) remove(name string) {
   515  	name = cleanEntryName(name)
   516  	if f.index != nil {
   517  		delete(f.index, name)
   518  	}
   519  	var filtered []*entry
   520  	for _, e := range f.stream {
   521  		if cleanEntryName(e.header.Name) == name {
   522  			continue
   523  		}
   524  		filtered = append(filtered, e)
   525  	}
   526  	f.stream = filtered
   527  }
   529  func (f *tarFile) get(name string) (e *entry, ok bool) {
   530  	if f.index == nil {
   531  		return nil, false
   532  	}
   533  	e, ok = f.index[cleanEntryName(name)]
   534  	return
   535  }
   537  func (f *tarFile) dump() []*entry {
   538  	return f.stream
   539  }
   541  type readCloser struct {
   542  	io.Reader
   543  	closeFunc func() error
   544  }
   546  func (rc readCloser) Close() error {
   547  	return rc.closeFunc()
   548  }
   550  func fileSectionReader(file *os.File) (*io.SectionReader, error) {
   551  	info, err := file.Stat()
   552  	if err != nil {
   553  		return nil, err
   554  	}
   555  	return io.NewSectionReader(file, 0, info.Size()), nil
   556  }
   558  func newTempFiles() *tempFiles {
   559  	return &tempFiles{}
   560  }
   562  type tempFiles struct {
   563  	files       []*os.File
   564  	filesMu     sync.Mutex
   565  	cleanupOnce sync.Once
   566  }
   568  func (tf *tempFiles) TempFile(dir, pattern string) (*os.File, error) {
   569  	f, err := os.CreateTemp(dir, pattern)
   570  	if err != nil {
   571  		return nil, err
   572  	}
   573  	tf.filesMu.Lock()
   574  	tf.files = append(tf.files, f)
   575  	tf.filesMu.Unlock()
   576  	return f, nil
   577  }
   579  func (tf *tempFiles) CleanupAll() (err error) {
   580  	tf.cleanupOnce.Do(func() {
   581  		err = tf.cleanupAll()
   582  	})
   583  	return
   584  }
   586  func (tf *tempFiles) cleanupAll() error {
   587  	tf.filesMu.Lock()
   588  	defer tf.filesMu.Unlock()
   589  	var allErr []error
   590  	for _, f := range tf.files {
   591  		if err := f.Close(); err != nil {
   592  			allErr = append(allErr, err)
   593  		}
   594  		if err := os.Remove(f.Name()); err != nil {
   595  			allErr = append(allErr, err)
   596  		}
   597  	}
   598  	tf.files = nil
   599  	return errorutil.Aggregate(allErr)
   600  }
   602  func newCountReadSeeker(r io.ReaderAt) (*countReadSeeker, error) {
   603  	pos := int64(0)
   604  	return &countReadSeeker{r: r, cPos: &pos}, nil
   605  }
   607  type countReadSeeker struct {
   608  	r    io.ReaderAt
   609  	cPos *int64
   611  	mu sync.Mutex
   612  }
   614  func (cr *countReadSeeker) Read(p []byte) (int, error) {
   615  	cr.mu.Lock()
   616  	defer cr.mu.Unlock()
   618  	n, err := cr.r.ReadAt(p, *cr.cPos)
   619  	if err == nil {
   620  		*cr.cPos += int64(n)
   621  	}
   622  	return n, err
   623  }
   625  func (cr *countReadSeeker) Seek(offset int64, whence int) (int64, error) {
   626  	cr.mu.Lock()
   627  	defer cr.mu.Unlock()
   629  	switch whence {
   630  	default:
   631  		return 0, fmt.Errorf("Unknown whence: %v", whence)
   632  	case io.SeekStart:
   633  	case io.SeekCurrent:
   634  		offset += *cr.cPos
   635  	case io.SeekEnd:
   636  		return 0, fmt.Errorf("Unsupported whence: %v", whence)
   637  	}
   639  	if offset < 0 {
   640  		return 0, fmt.Errorf("invalid offset")
   641  	}
   642  	*cr.cPos = offset
   643  	return offset, nil
   644  }
   646  func (cr *countReadSeeker) currentPos() int64 {
   647  	cr.mu.Lock()
   648  	defer cr.mu.Unlock()
   650  	return *cr.cPos
   651  }
   653  func decompressBlob(org *io.SectionReader, tmp *tempFiles) (*io.SectionReader, error) {
   654  	if org.Size() < 4 {
   655  		return org, nil
   656  	}
   657  	src := make([]byte, 4)
   658  	if _, err := org.Read(src); err != nil && err != io.EOF {
   659  		return nil, err
   660  	}
   661  	var dR io.Reader
   662  	if bytes.Equal([]byte{0x1F, 0x8B, 0x08}, src[:3]) {
   663  		// gzip
   664  		dgR, err := gzip.NewReader(io.NewSectionReader(org, 0, org.Size()))
   665  		if err != nil {
   666  			return nil, err
   667  		}
   668  		defer dgR.Close()
   669  		dR = io.Reader(dgR)
   670  	} else if bytes.Equal([]byte{0x28, 0xb5, 0x2f, 0xfd}, src[:4]) {
   671  		// zstd
   672  		dzR, err := zstd.NewReader(io.NewSectionReader(org, 0, org.Size()))
   673  		if err != nil {
   674  			return nil, err
   675  		}
   676  		defer dzR.Close()
   677  		dR = io.Reader(dzR)
   678  	} else {
   679  		// uncompressed
   680  		return io.NewSectionReader(org, 0, org.Size()), nil
   681  	}
   682  	b, err := tmp.TempFile("", "uncompresseddata")
   683  	if err != nil {
   684  		return nil, err
   685  	}
   686  	if _, err := io.Copy(b, dR); err != nil {
   687  		return nil, err
   688  	}
   689  	return fileSectionReader(b)
   690  }

