...

Source file src/cuelang.org/go/mod/modzip/zip_test.go

Documentation: cuelang.org/go/mod/modzip

     1  // Copyright 2019 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package modzip_test
     6  
     7  import (
     8  	"archive/zip"
     9  	"bytes"
    10  	"fmt"
    11  	"io"
    12  	"os"
    13  	"path"
    14  	"path/filepath"
    15  	"runtime"
    16  	"strconv"
    17  	"strings"
    18  	"testing"
    19  	"time"
    20  
    21  	"github.com/google/go-cmp/cmp"
    22  
    23  	"cuelang.org/go/internal/cuetest"
    24  	"cuelang.org/go/mod/module"
    25  	"cuelang.org/go/mod/modzip"
    26  	"golang.org/x/mod/sumdb/dirhash"
    27  	"golang.org/x/tools/txtar"
    28  )
    29  
    30  type testParams struct {
    31  	path, version, wantErr, hash string
    32  	want                         string
    33  	archive                      *txtar.Archive
    34  }
    35  
    36  // readTest loads a test from a txtar file. The comment section of the file
    37  // should contain lines with key=value pairs. Valid keys are the field names
    38  // from testParams.
    39  func readTest(file string) (testParams, error) {
    40  	var test testParams
    41  	var err error
    42  	test.archive, err = txtar.ParseFile(file)
    43  	if err != nil {
    44  		return testParams{}, err
    45  	}
    46  	for i, f := range test.archive.Files {
    47  		if f.Name == "want" {
    48  			test.want = string(f.Data)
    49  			test.archive.Files = append(test.archive.Files[:i], test.archive.Files[i+1:]...)
    50  			break
    51  		}
    52  	}
    53  
    54  	lines := strings.Split(string(test.archive.Comment), "\n")
    55  	for n, line := range lines {
    56  		n++ // report line numbers starting with 1
    57  		line = strings.TrimSpace(line)
    58  		if line == "" || line[0] == '#' {
    59  			continue
    60  		}
    61  		eq := strings.IndexByte(line, '=')
    62  		if eq < 0 {
    63  			return testParams{}, fmt.Errorf("%s:%d: missing = separator", file, n)
    64  		}
    65  		key, value := strings.TrimSpace(line[:eq]), strings.TrimSpace(line[eq+1:])
    66  		if strings.HasPrefix(value, "\"") {
    67  			unq, err := strconv.Unquote(value)
    68  			if err != nil {
    69  				return testParams{}, fmt.Errorf("%s:%d: %v", file, n, err)
    70  			}
    71  			value = unq
    72  		}
    73  		switch key {
    74  		case "path":
    75  			test.path = value
    76  		case "version":
    77  			test.version = value
    78  		case "wantErr":
    79  			test.wantErr = value
    80  		case "hash":
    81  			test.hash = value
    82  		default:
    83  			return testParams{}, fmt.Errorf("%s:%d: unknown key %q", file, n, key)
    84  		}
    85  	}
    86  
    87  	return test, nil
    88  }
    89  
    90  func extractTxtarToTempDir(t testing.TB, arc *txtar.Archive) (dir string, err error) {
    91  	dir = t.TempDir()
    92  	for _, f := range arc.Files {
    93  		filePath := filepath.Join(dir, f.Name)
    94  		if err := os.MkdirAll(filepath.Dir(filePath), 0777); err != nil {
    95  			return "", err
    96  		}
    97  		if err := os.WriteFile(filePath, f.Data, 0666); err != nil {
    98  			return "", err
    99  		}
   100  	}
   101  	return dir, nil
   102  }
   103  
   104  func extractTxtarToTempZip(t *testing.T, arc *txtar.Archive) (zipPath string, err error) {
   105  	zipPath = filepath.Join(t.TempDir(), "txtar.zip")
   106  
   107  	zipFile, err := os.Create(zipPath)
   108  	if err != nil {
   109  		return "", err
   110  	}
   111  	defer func() {
   112  		if cerr := zipFile.Close(); err == nil && cerr != nil {
   113  			err = cerr
   114  		}
   115  	}()
   116  
   117  	zw := zip.NewWriter(zipFile)
   118  	for _, f := range arc.Files {
   119  		zf, err := zw.Create(f.Name)
   120  		if err != nil {
   121  			return "", err
   122  		}
   123  		if _, err := zf.Write(f.Data); err != nil {
   124  			return "", err
   125  		}
   126  	}
   127  	if err := zw.Close(); err != nil {
   128  		return "", err
   129  	}
   130  	return zipFile.Name(), nil
   131  }
   132  
   133  type fakeFileIO struct{}
   134  
   135  func (fakeFileIO) Path(f fakeFile) string                { return f.name }
   136  func (fakeFileIO) Lstat(f fakeFile) (os.FileInfo, error) { return fakeFileInfo{f}, nil }
   137  func (fakeFileIO) Open(f fakeFile) (io.ReadCloser, error) {
   138  	if f.data != nil {
   139  		return io.NopCloser(bytes.NewReader(f.data)), nil
   140  	}
   141  	if f.size >= uint64(modzip.MaxZipFile<<1) {
   142  		return nil, fmt.Errorf("cannot open fakeFile of size %d", f.size)
   143  	}
   144  	return io.NopCloser(io.LimitReader(zeroReader{}, int64(f.size))), nil
   145  }
   146  
   147  type fakeFile struct {
   148  	name  string
   149  	isDir bool
   150  	size  uint64
   151  	data  []byte // if nil, Open will access a sequence of 0-bytes
   152  }
   153  
   154  type fakeFileInfo struct {
   155  	f fakeFile
   156  }
   157  
   158  func (fi fakeFileInfo) Name() string {
   159  	return path.Base(fi.f.name)
   160  }
   161  
   162  func (fi fakeFileInfo) Size() int64 {
   163  	if fi.f.size == 0 {
   164  		return int64(len(fi.f.data))
   165  	}
   166  	return int64(fi.f.size)
   167  }
   168  func (fi fakeFileInfo) Mode() os.FileMode {
   169  	if fi.f.isDir {
   170  		return os.ModeDir | 0o755
   171  	}
   172  	return 0o644
   173  }
   174  
   175  func (fi fakeFileInfo) ModTime() time.Time { return time.Time{} }
   176  func (fi fakeFileInfo) IsDir() bool        { return fi.f.isDir }
   177  func (fi fakeFileInfo) Sys() interface{}   { return nil }
   178  
   179  type zeroReader struct{}
   180  
   181  func (r zeroReader) Read(b []byte) (int, error) {
   182  	for i := range b {
   183  		b[i] = 0
   184  	}
   185  	return len(b), nil
   186  }
   187  
   188  func formatCheckedFiles(cf modzip.CheckedFiles) string {
   189  	buf := &bytes.Buffer{}
   190  	fmt.Fprintf(buf, "valid:\n")
   191  	for _, f := range cf.Valid {
   192  		fmt.Fprintln(buf, f)
   193  	}
   194  	fmt.Fprintf(buf, "\nomitted:\n")
   195  	for _, f := range cf.Omitted {
   196  		fmt.Fprintf(buf, "%s: %v\n", f.Path, f.Err)
   197  	}
   198  	fmt.Fprintf(buf, "\ninvalid:\n")
   199  	for _, f := range cf.Invalid {
   200  		fmt.Fprintf(buf, "%s: %v\n", f.Path, f.Err)
   201  	}
   202  	return buf.String()
   203  }
   204  
   205  func TestCheckFilesWithDirWithTrailingSlash(t *testing.T) {
   206  	t.Parallel()
   207  	// When checking a zip file,
   208  	files := []fakeFile{{
   209  		name:  "cue.mod/",
   210  		isDir: true,
   211  	}, {
   212  		name: "cue.mod/module.cue",
   213  		data: []byte(`module: "example.com/m"`),
   214  	}}
   215  	_, err := modzip.CheckFiles[fakeFile](files, fakeFileIO{})
   216  	if err != nil {
   217  		t.Fatal(err)
   218  	}
   219  }
   220  
   221  // TestCheckFiles verifies behavior of CheckFiles. Note that CheckFiles is also
   222  // covered by TestCreate, TestCreateDir, and TestCreateSizeLimits, so this test
   223  // focuses on how multiple errors and omissions are reported, rather than trying
   224  // to cover every case.
   225  func TestCheckFiles(t *testing.T) {
   226  	t.Parallel()
   227  	testPaths, err := filepath.Glob(filepath.FromSlash("testdata/check_files/*.txt"))
   228  	if err != nil {
   229  		t.Fatal(err)
   230  	}
   231  	for _, testPath := range testPaths {
   232  		testPath := testPath
   233  		name := strings.TrimSuffix(filepath.Base(testPath), ".txt")
   234  		t.Run(name, func(t *testing.T) {
   235  			t.Parallel()
   236  
   237  			// Load the test.
   238  			test, err := readTest(testPath)
   239  			if err != nil {
   240  				t.Fatal(err)
   241  			}
   242  			t.Logf("test file %s", testPath)
   243  			files := make([]fakeFile, 0, len(test.archive.Files))
   244  			for _, tf := range test.archive.Files {
   245  				files = append(files, fakeFile{
   246  					name: tf.Name,
   247  					size: uint64(len(tf.Data)),
   248  					data: tf.Data,
   249  				})
   250  			}
   251  
   252  			// Check the files.
   253  			cf, _ := modzip.CheckFiles[fakeFile](files, fakeFileIO{})
   254  			got := formatCheckedFiles(cf)
   255  			if diff := cmp.Diff(test.want, got); diff != "" {
   256  				t.Errorf("unexpected result; (-want +got):\n%s", diff)
   257  			}
   258  			// Check that the error (if any) is just a list of invalid files.
   259  			// SizeError is not covered in this test.
   260  			var gotErr string
   261  			wantErr := test.wantErr
   262  			if wantErr == "" && len(cf.Invalid) > 0 {
   263  				wantErr = modzip.FileErrorList(cf.Invalid).Error()
   264  			}
   265  			if err := cf.Err(); err != nil {
   266  				gotErr = err.Error()
   267  			}
   268  			if gotErr != wantErr {
   269  				t.Errorf("got error:\n%s\n\nwant error:\n%s", gotErr, wantErr)
   270  			}
   271  		})
   272  	}
   273  }
   274  
   275  // TestCheckDir verifies behavior of the CheckDir function. Note that CheckDir
   276  // relies on CheckFiles and listFilesInDir (called by CreateFromDir), so this
   277  // test focuses on how multiple errors and omissions are reported, rather than
   278  // trying to cover every case.
   279  func TestCheckDir(t *testing.T) {
   280  	t.Parallel()
   281  	testPaths, err := filepath.Glob(filepath.FromSlash("testdata/check_dir/*.txt"))
   282  	if err != nil {
   283  		t.Fatal(err)
   284  	}
   285  	for _, testPath := range testPaths {
   286  		testPath := testPath
   287  		name := strings.TrimSuffix(filepath.Base(testPath), ".txt")
   288  		t.Run(name, func(t *testing.T) {
   289  			t.Parallel()
   290  
   291  			// Load the test and extract the files to a temporary directory.
   292  			test, err := readTest(testPath)
   293  			if err != nil {
   294  				t.Fatal(err)
   295  			}
   296  			t.Logf("test file %s", testPath)
   297  			tmpDir, err := extractTxtarToTempDir(t, test.archive)
   298  			if err != nil {
   299  				t.Fatal(err)
   300  			}
   301  
   302  			// Check the directory.
   303  			cf, _ := modzip.CheckDir(tmpDir)
   304  			rep := strings.NewReplacer(tmpDir, "$work", `'\''`, `'\''`, string(os.PathSeparator), "/")
   305  			got := rep.Replace(formatCheckedFiles(cf))
   306  			if diff := cmp.Diff(test.want, got); diff != "" {
   307  				t.Errorf("unexpected result; (-want +got):\n%s", diff)
   308  			}
   309  
   310  			// Check that the error (if any) is just a list of invalid files.
   311  			// SizeError is not covered in this test.
   312  			var gotErr string
   313  			wantErr := test.wantErr
   314  			if wantErr == "" && len(cf.Invalid) > 0 {
   315  				wantErr = modzip.FileErrorList(cf.Invalid).Error()
   316  			}
   317  			if err := cf.Err(); err != nil {
   318  				gotErr = err.Error()
   319  			}
   320  			if gotErr != wantErr {
   321  				t.Errorf("got error:\n%s\n\nwant error:\n%s", gotErr, wantErr)
   322  			}
   323  		})
   324  	}
   325  }
   326  
   327  // TestCheckZip verifies behavior of CheckZip. Note that CheckZip is also
   328  // covered by TestUnzip, so this test focuses on how multiple errors are
   329  // reported, rather than trying to cover every case.
   330  func TestCheckZip(t *testing.T) {
   331  	t.Parallel()
   332  	testPaths, err := filepath.Glob(filepath.FromSlash("testdata/check_zip/*.txt"))
   333  	if err != nil {
   334  		t.Fatal(err)
   335  	}
   336  	for _, testPath := range testPaths {
   337  		testPath := testPath
   338  		name := strings.TrimSuffix(filepath.Base(testPath), ".txt")
   339  		t.Run(name, func(t *testing.T) {
   340  			t.Parallel()
   341  
   342  			// Load the test and extract the files to a temporary zip file.
   343  			test, err := readTest(testPath)
   344  			if err != nil {
   345  				t.Fatal(err)
   346  			}
   347  			t.Logf("test file %s", testPath)
   348  			tmpZipPath, err := extractTxtarToTempZip(t, test.archive)
   349  			if err != nil {
   350  				t.Fatal(err)
   351  			}
   352  
   353  			// Check the zip.
   354  			m := module.MustNewVersion(test.path, test.version)
   355  			cf, checkZipErr := modzip.CheckZipFile(m, tmpZipPath)
   356  			got := formatCheckedFiles(cf)
   357  			if diff := cmp.Diff(test.want, got); diff != "" {
   358  				t.Errorf("unexpected result; (-want +got):\n%s", diff)
   359  			}
   360  
   361  			// Check that the error (if any) is just a list of invalid files.
   362  			// SizeError is not covered in this test.
   363  			var gotErr string
   364  			wantErr := test.wantErr
   365  			if wantErr == "" && len(cf.Invalid) > 0 {
   366  				wantErr = modzip.FileErrorList(cf.Invalid).Error()
   367  			}
   368  			if checkZipErr != nil {
   369  				gotErr = checkZipErr.Error()
   370  			} else if err := cf.Err(); err != nil {
   371  				gotErr = err.Error()
   372  			}
   373  			if gotErr != wantErr {
   374  				t.Errorf("got error:\n%s\n\nwant error:\n%s", gotErr, wantErr)
   375  			}
   376  		})
   377  	}
   378  }
   379  
   380  func TestCreate(t *testing.T) {
   381  	t.Parallel()
   382  	testDir := filepath.FromSlash("testdata/create")
   383  	testInfos, err := os.ReadDir(testDir)
   384  	if err != nil {
   385  		t.Fatal(err)
   386  	}
   387  	for _, testInfo := range testInfos {
   388  		testInfo := testInfo
   389  		base := filepath.Base(testInfo.Name())
   390  		if filepath.Ext(base) != ".txt" {
   391  			continue
   392  		}
   393  		t.Run(base[:len(base)-len(".txt")], func(t *testing.T) {
   394  			t.Parallel()
   395  
   396  			// Load the test.
   397  			testPath := filepath.Join(testDir, testInfo.Name())
   398  			test, err := readTest(testPath)
   399  			if err != nil {
   400  				t.Fatal(err)
   401  			}
   402  			t.Logf("test file: %s", testPath)
   403  
   404  			// Write zip to temporary file.
   405  			tmpZipFile := tempFile(t, "tmp.zip")
   406  			m := module.MustNewVersion(test.path, test.version)
   407  			files := make([]fakeFile, len(test.archive.Files))
   408  			for i, tf := range test.archive.Files {
   409  				files[i] = fakeFile{
   410  					name: tf.Name,
   411  					size: uint64(len(tf.Data)),
   412  					data: tf.Data,
   413  				}
   414  			}
   415  			if err := modzip.Create[fakeFile](tmpZipFile, m, files, fakeFileIO{}); err != nil {
   416  				if test.wantErr == "" {
   417  					t.Fatalf("unexpected error: %v", err)
   418  				} else if !strings.Contains(err.Error(), test.wantErr) {
   419  					t.Fatalf("got error %q; want error containing %q", err.Error(), test.wantErr)
   420  				} else {
   421  					return
   422  				}
   423  			} else if test.wantErr != "" {
   424  				t.Fatalf("unexpected success; wanted error containing %q", test.wantErr)
   425  			}
   426  			if err := tmpZipFile.Close(); err != nil {
   427  				t.Fatal(err)
   428  			}
   429  
   430  			// Hash zip file, compare with known value.
   431  			if hash, err := dirhash.HashZip(tmpZipFile.Name(), dirhash.Hash1); err != nil {
   432  				t.Fatal(err)
   433  			} else if hash != test.hash {
   434  				t.Errorf("got hash: %q\nwant: %q", hash, test.hash)
   435  			}
   436  			assertNoExcludedFiles(t, tmpZipFile.Name())
   437  		})
   438  	}
   439  }
   440  
   441  func assertNoExcludedFiles(t *testing.T, zf string) {
   442  	z, err := zip.OpenReader(zf)
   443  	if err != nil {
   444  		t.Fatal(err)
   445  	}
   446  	defer z.Close()
   447  	for _, f := range z.File {
   448  		if shouldExclude(f) {
   449  			t.Errorf("file %s should have been excluded but was not", f.Name)
   450  		}
   451  	}
   452  }
   453  
   454  func shouldExclude(f *zip.File) bool {
   455  	r, err := f.Open()
   456  	if err != nil {
   457  		panic(err)
   458  	}
   459  	defer r.Close()
   460  	data, err := io.ReadAll(r)
   461  	if err != nil {
   462  		panic(err)
   463  	}
   464  	return bytes.Contains(data, []byte("excluded"))
   465  }
   466  
   467  func TestCreateFromDir(t *testing.T) {
   468  	t.Parallel()
   469  	testDir := filepath.FromSlash("testdata/create_from_dir")
   470  	testInfos, err := os.ReadDir(testDir)
   471  	if err != nil {
   472  		t.Fatal(err)
   473  	}
   474  	for _, testInfo := range testInfos {
   475  		testInfo := testInfo
   476  		base := filepath.Base(testInfo.Name())
   477  		if filepath.Ext(base) != ".txt" {
   478  			continue
   479  		}
   480  		t.Run(base[:len(base)-len(".txt")], func(t *testing.T) {
   481  			t.Parallel()
   482  
   483  			// Load the test.
   484  			testPath := filepath.Join(testDir, testInfo.Name())
   485  			test, err := readTest(testPath)
   486  			if err != nil {
   487  				t.Fatal(err)
   488  			}
   489  			t.Logf("test file %s", testPath)
   490  
   491  			// Write files to a temporary directory.
   492  			tmpDir, err := extractTxtarToTempDir(t, test.archive)
   493  			if err != nil {
   494  				t.Fatal(err)
   495  			}
   496  
   497  			// Create zip from the directory.
   498  			tmpZipFile := tempFile(t, "tmp.zip")
   499  			m := module.MustNewVersion(test.path, test.version)
   500  			if err := modzip.CreateFromDir(tmpZipFile, m, tmpDir); err != nil {
   501  				if test.wantErr == "" {
   502  					t.Fatalf("unexpected error: %v", err)
   503  				} else if !strings.Contains(err.Error(), test.wantErr) {
   504  					t.Fatalf("got error %q; want error containing %q", err, test.wantErr)
   505  				} else {
   506  					return
   507  				}
   508  			} else if test.wantErr != "" {
   509  				t.Fatalf("unexpected success; want error containing %q", test.wantErr)
   510  			}
   511  
   512  			// Hash zip file, compare with known value.
   513  			if hash, err := dirhash.HashZip(tmpZipFile.Name(), dirhash.Hash1); err != nil {
   514  				t.Fatal(err)
   515  			} else if hash != test.hash {
   516  				t.Fatalf("got hash: %q\nwant: %q", hash, test.hash)
   517  			}
   518  			assertNoExcludedFiles(t, tmpZipFile.Name())
   519  		})
   520  	}
   521  }
   522  
   523  func TestCreateFromDirSpecial(t *testing.T) {
   524  	t.Parallel()
   525  	for _, test := range []struct {
   526  		desc     string
   527  		setup    func(t *testing.T, tmpDir string) string
   528  		wantHash string
   529  	}{
   530  		{
   531  			desc: "ignore_empty_dir",
   532  			setup: func(t *testing.T, tmpDir string) string {
   533  				if err := os.Mkdir(filepath.Join(tmpDir, "empty"), 0777); err != nil {
   534  					t.Fatal(err)
   535  				}
   536  				mustWriteFile(
   537  					filepath.Join(tmpDir, "cue.mod/module.cue"),
   538  					`module: "example.com/m"`,
   539  				)
   540  				return tmpDir
   541  			},
   542  			wantHash: "h1:vEUjl4tTsFcZJC/Ed/Rph2nVDCMG7OFC4wrQDfxF3n0=",
   543  		}, {
   544  			desc: "ignore_symlink",
   545  			setup: func(t *testing.T, tmpDir string) string {
   546  				if err := os.Symlink(tmpDir, filepath.Join(tmpDir, "link")); err != nil {
   547  					switch runtime.GOOS {
   548  					case "plan9", "windows":
   549  						t.Skipf("could not create symlink: %v", err)
   550  					default:
   551  						t.Fatal(err)
   552  					}
   553  				}
   554  				mustWriteFile(
   555  					filepath.Join(tmpDir, "cue.mod/module.cue"),
   556  					`module: "example.com/m"`,
   557  				)
   558  				return tmpDir
   559  			},
   560  			wantHash: "h1:vEUjl4tTsFcZJC/Ed/Rph2nVDCMG7OFC4wrQDfxF3n0=",
   561  		}, {
   562  			desc: "dir_is_vendor",
   563  			setup: func(t *testing.T, tmpDir string) string {
   564  				vendorDir := filepath.Join(tmpDir, "vendor")
   565  				if err := os.Mkdir(vendorDir, 0777); err != nil {
   566  					t.Fatal(err)
   567  				}
   568  				mustWriteFile(
   569  					filepath.Join(vendorDir, "cue.mod/module.cue"),
   570  					`module: "example.com/m"`,
   571  				)
   572  				return vendorDir
   573  			},
   574  			wantHash: "h1:vEUjl4tTsFcZJC/Ed/Rph2nVDCMG7OFC4wrQDfxF3n0=",
   575  		},
   576  	} {
   577  		t.Run(test.desc, func(t *testing.T) {
   578  			tmpDir := t.TempDir()
   579  			dir := test.setup(t, tmpDir)
   580  
   581  			tmpZipFile := tempFile(t, "tmp.zip")
   582  			m := module.MustNewVersion("example.com/m@v1", "v1.0.0")
   583  
   584  			if err := modzip.CreateFromDir(tmpZipFile, m, dir); err != nil {
   585  				t.Fatal(err)
   586  			}
   587  			if err := tmpZipFile.Close(); err != nil {
   588  				t.Fatal(err)
   589  			}
   590  
   591  			if hash, err := dirhash.HashZip(tmpZipFile.Name(), dirhash.Hash1); err != nil {
   592  				t.Fatal(err)
   593  			} else if hash != test.wantHash {
   594  				t.Fatalf("got hash %q; want %q", hash, test.wantHash)
   595  			}
   596  		})
   597  	}
   598  }
   599  
   600  func TestUnzip(t *testing.T) {
   601  	t.Parallel()
   602  	testDir := filepath.FromSlash("testdata/unzip")
   603  	testInfos, err := os.ReadDir(testDir)
   604  	if err != nil {
   605  		t.Fatal(err)
   606  	}
   607  	for _, testInfo := range testInfos {
   608  		base := filepath.Base(testInfo.Name())
   609  		if filepath.Ext(base) != ".txt" {
   610  			continue
   611  		}
   612  		t.Run(base[:len(base)-len(".txt")], func(t *testing.T) {
   613  			// Load the test.
   614  			testPath := filepath.Join(testDir, testInfo.Name())
   615  			test, err := readTest(testPath)
   616  			if err != nil {
   617  				t.Fatal(err)
   618  			}
   619  			t.Logf("test file %s", testPath)
   620  
   621  			// Convert txtar to temporary zip file.
   622  			tmpZipPath, err := extractTxtarToTempZip(t, test.archive)
   623  			if err != nil {
   624  				t.Fatal(err)
   625  			}
   626  
   627  			// Extract to a temporary directory.
   628  			tmpDir := t.TempDir()
   629  			m := module.MustNewVersion(test.path, test.version)
   630  			if err := modzip.Unzip(tmpDir, m, tmpZipPath); err != nil {
   631  				if test.wantErr == "" {
   632  					t.Fatalf("unexpected error: %v", err)
   633  				} else if !strings.Contains(err.Error(), test.wantErr) {
   634  					t.Fatalf("got error %q; want error containing %q", err.Error(), test.wantErr)
   635  				} else {
   636  					return
   637  				}
   638  			} else if test.wantErr != "" {
   639  				t.Fatalf("unexpected success; wanted error containing %q", test.wantErr)
   640  			}
   641  
   642  			// Hash the directory, compare to known value.
   643  			if hash, err := dirhash.HashDir(tmpDir, "", dirhash.Hash1); err != nil {
   644  				t.Fatal(err)
   645  			} else if hash != test.hash {
   646  				t.Fatalf("got hash %q\nwant: %q", hash, test.hash)
   647  			}
   648  		})
   649  	}
   650  }
   651  
   652  type sizeLimitTest struct {
   653  	desc              string
   654  	files             []fakeFile
   655  	wantErr           string
   656  	wantCheckFilesErr string
   657  	wantCreateErr     string
   658  	wantCheckZipErr   string
   659  	wantUnzipErr      string
   660  }
   661  
// sizeLimitTests is shared by TestCreateSizeLimits and TestUnzipSizeLimits.
//
// Each case lists the files of a module. A case's wantErr applies to every
// stage checked by those tests unless the stage-specific field is set. An
// empty expectation means the stage must succeed. Cases come in pairs: one
// sitting exactly at a limit, which must pass, and one a byte over it.
var sizeLimitTests = [...]sizeLimitTest{
	{
		// Declared sizes sum to exactly modzip.MaxZipFile.
		desc: "total_large",
		files: []fakeFile{{
			name: "large.go",
			size: modzip.MaxZipFile - uint64(len(`module: "example.com/m@v1"`)),
		}, {
			name: "cue.mod/module.cue",
			data: []byte(`module: "example.com/m@v1"`),
		}},
	}, {
		// One byte over modzip.MaxZipFile. The file-based and zip-based
		// checks report this with different messages.
		desc: "total_too_large",
		files: []fakeFile{{
			name: "large.go",
			size: modzip.MaxZipFile - uint64(len(`module: "example.com/m@v1"`)) + 1,
		}, {
			name: "cue.mod/module.cue",
			data: []byte(`module: "example.com/m@v1"`),
		}},
		wantCheckFilesErr: "module source tree too large",
		wantCreateErr:     "module source tree too large",
		wantCheckZipErr:   "total uncompressed size of module contents too large",
		wantUnzipErr:      "total uncompressed size of module contents too large",
	}, {
		// cue.mod/module.cue of exactly modzip.MaxCUEMod zero bytes.
		desc: "large_cuemod",
		files: []fakeFile{{
			name: "cue.mod/module.cue",
			size: modzip.MaxCUEMod,
		}},
	}, {
		// cue.mod/module.cue one byte over modzip.MaxCUEMod.
		desc: "too_large_cuemod",
		files: []fakeFile{{
			name: "cue.mod/module.cue",
			size: modzip.MaxCUEMod + 1,
		}},
		wantErr: "cue.mod/module.cue file too large",
	}, {
		// LICENSE of exactly modzip.MaxLICENSE zero bytes.
		desc: "large_license",
		files: []fakeFile{{
			name: "LICENSE",
			size: modzip.MaxLICENSE,
		}, {
			name: "cue.mod/module.cue",
			data: []byte(`module: "example.com/m@v1"`),
		}},
	}, {
		// LICENSE one byte over modzip.MaxLICENSE.
		desc: "too_large_license",
		files: []fakeFile{{
			name: "LICENSE",
			size: modzip.MaxLICENSE + 1,
		}, {
			name: "cue.mod/module.cue",
			data: []byte(`module: "example.com/m@v1"`),
		}},
		wantErr: "LICENSE file too large",
	},
}

// sizeLimitVersion is the module version used for every size-limit case.
// NOTE(review): its path differs from the `module:` value written into the
// cases' cue.mod/module.cue; presumably irrelevant to the size checks, but
// confirm.
var sizeLimitVersion = module.MustNewVersion("example.com/large@v1", "v1.0.0")
   722  
   723  func TestCreateSizeLimits(t *testing.T) {
   724  	if testing.Short() || cuetest.RaceEnabled {
   725  		t.Skip("creating large files takes time")
   726  	}
   727  	t.Parallel()
   728  	tests := append(sizeLimitTests[:], sizeLimitTest{
   729  		// negative file size may happen when size is represented as uint64
   730  		// but is cast to int64, as is the case in zip files.
   731  		desc: "negative",
   732  		files: []fakeFile{{
   733  			name: "neg.go",
   734  			size: 0x8000000000000000,
   735  		}, {
   736  			name: "cue.mod/module.cue",
   737  			data: []byte(`module: "example.com/m@v1"`),
   738  		}},
   739  		wantErr: "module source tree too large",
   740  	}, sizeLimitTest{
   741  		desc: "size_is_a_lie",
   742  		files: []fakeFile{{
   743  			name: "lie.go",
   744  			size: 1,
   745  			data: []byte(`package large`),
   746  		}, {
   747  			name: "cue.mod/module.cue",
   748  			data: []byte(`module: "example.com/m@v1"`),
   749  		}},
   750  		wantCreateErr: "larger than declared size",
   751  	})
   752  
   753  	for _, test := range tests {
   754  		test := test
   755  		t.Run(test.desc, func(t *testing.T) {
   756  			t.Parallel()
   757  
   758  			wantCheckFilesErr := test.wantCheckFilesErr
   759  			if wantCheckFilesErr == "" {
   760  				wantCheckFilesErr = test.wantErr
   761  			}
   762  			if _, err := modzip.CheckFiles[fakeFile](test.files, fakeFileIO{}); err == nil && wantCheckFilesErr != "" {
   763  				t.Fatalf("CheckFiles: unexpected success; want error containing %q", wantCheckFilesErr)
   764  			} else if err != nil && wantCheckFilesErr == "" {
   765  				t.Fatalf("CheckFiles: got error %q; want success", err)
   766  			} else if err != nil && !strings.Contains(err.Error(), wantCheckFilesErr) {
   767  				t.Fatalf("CheckFiles: got error %q; want error containing %q", err, wantCheckFilesErr)
   768  			}
   769  
   770  			wantCreateErr := test.wantCreateErr
   771  			if wantCreateErr == "" {
   772  				wantCreateErr = test.wantErr
   773  			}
   774  			if err := modzip.Create[fakeFile](io.Discard, sizeLimitVersion, test.files, fakeFileIO{}); err == nil && wantCreateErr != "" {
   775  				t.Fatalf("Create: unexpected success; want error containing %q", wantCreateErr)
   776  			} else if err != nil && wantCreateErr == "" {
   777  				t.Fatalf("Create: got error %q; want success", err)
   778  			} else if err != nil && !strings.Contains(err.Error(), wantCreateErr) {
   779  				t.Fatalf("Create: got error %q; want error containing %q", err, wantCreateErr)
   780  			}
   781  		})
   782  	}
   783  }
   784  
   785  func TestUnzipSizeLimits(t *testing.T) {
   786  	if testing.Short() || cuetest.RaceEnabled {
   787  		t.Skip("creating large files takes time")
   788  	}
   789  	t.Parallel()
   790  	for _, test := range sizeLimitTests {
   791  		test := test
   792  		t.Run(test.desc, func(t *testing.T) {
   793  			t.Parallel()
   794  			tmpZipFile := tempFile(t, "tmp.zip")
   795  
   796  			zw := zip.NewWriter(tmpZipFile)
   797  			for _, tf := range test.files {
   798  				zf, err := zw.Create(fakeFileIO{}.Path(tf))
   799  				if err != nil {
   800  					t.Fatal(err)
   801  				}
   802  				rc, err := fakeFileIO{}.Open(tf)
   803  				if err != nil {
   804  					t.Fatal(err)
   805  				}
   806  				_, err = io.Copy(zf, rc)
   807  				rc.Close()
   808  				if err != nil {
   809  					t.Fatal(err)
   810  				}
   811  			}
   812  			if err := zw.Close(); err != nil {
   813  				t.Fatal(err)
   814  			}
   815  			if err := tmpZipFile.Close(); err != nil {
   816  				t.Fatal(err)
   817  			}
   818  
   819  			tmpDir := t.TempDir()
   820  
   821  			wantCheckZipErr := test.wantCheckZipErr
   822  			if wantCheckZipErr == "" {
   823  				wantCheckZipErr = test.wantErr
   824  			}
   825  			cf, err := modzip.CheckZipFile(sizeLimitVersion, tmpZipFile.Name())
   826  			if err == nil {
   827  				err = cf.Err()
   828  			}
   829  			if err == nil && wantCheckZipErr != "" {
   830  				t.Fatalf("CheckZip: unexpected success; want error containing %q", wantCheckZipErr)
   831  			} else if err != nil && wantCheckZipErr == "" {
   832  				t.Fatalf("CheckZip: got error %q; want success", err)
   833  			} else if err != nil && !strings.Contains(err.Error(), wantCheckZipErr) {
   834  				t.Fatalf("CheckZip: got error %q; want error containing %q", err, wantCheckZipErr)
   835  			}
   836  
   837  			wantUnzipErr := test.wantUnzipErr
   838  			if wantUnzipErr == "" {
   839  				wantUnzipErr = test.wantErr
   840  			}
   841  			if err := modzip.Unzip(tmpDir, sizeLimitVersion, tmpZipFile.Name()); err == nil && wantUnzipErr != "" {
   842  				t.Fatalf("Unzip: unexpected success; want error containing %q", wantUnzipErr)
   843  			} else if err != nil && wantUnzipErr == "" {
   844  				t.Fatalf("Unzip: got error %q; want success", err)
   845  			} else if err != nil && !strings.Contains(err.Error(), wantUnzipErr) {
   846  				t.Fatalf("Unzip: got error %q; want error containing %q", err, wantUnzipErr)
   847  			}
   848  		})
   849  	}
   850  }
   851  
   852  func TestUnzipSizeLimitsSpecial(t *testing.T) {
   853  	if testing.Short() || cuetest.RaceEnabled {
   854  		t.Skip("skipping test; creating large files takes time")
   855  	}
   856  
   857  	t.Parallel()
   858  	for _, test := range []struct {
   859  		desc     string
   860  		wantErr  string
   861  		m        module.Version
   862  		writeZip func(t *testing.T, zipFile *os.File)
   863  	}{
   864  		{
   865  			desc: "large_zip",
   866  			m:    module.MustNewVersion("example.com/m@v1", "v1.0.0"),
   867  			writeZip: func(t *testing.T, zipFile *os.File) {
   868  				if err := zipFile.Truncate(modzip.MaxZipFile); err != nil {
   869  					t.Fatal(err)
   870  				}
   871  			},
   872  			// this is not an error we care about; we're just testing whether
   873  			// Unzip checks the size of the file before opening.
   874  			// It's harder to create a valid zip file of exactly the right size.
   875  			wantErr: "not a valid zip file",
   876  		}, {
   877  			desc: "too_large_zip",
   878  			m:    module.MustNewVersion("example.com/m@v1", "v1.0.0"),
   879  			writeZip: func(t *testing.T, zipFile *os.File) {
   880  				if err := zipFile.Truncate(modzip.MaxZipFile + 1); err != nil {
   881  					t.Fatal(err)
   882  				}
   883  			},
   884  			wantErr: "module zip file is too large",
   885  		}, {
   886  			desc: "size_is_a_lie",
   887  			m:    module.MustNewVersion("example.com/m@v1", "v1.0.0"),
   888  			writeZip: func(t *testing.T, zipFile *os.File) {
   889  				// Create a normal zip file in memory containing one file full of zero
   890  				// bytes. Use a distinctive size so we can find it later.
   891  				zipBuf := &bytes.Buffer{}
   892  				zw := zip.NewWriter(zipBuf)
   893  				f, err := zw.Create("cue.mod/module.cue")
   894  				if err != nil {
   895  					t.Fatal(err)
   896  				}
   897  				realSize := 0x0BAD
   898  				buf := make([]byte, realSize)
   899  				if _, err := f.Write(buf); err != nil {
   900  					t.Fatal(err)
   901  				}
   902  				if err := zw.Close(); err != nil {
   903  					t.Fatal(err)
   904  				}
   905  
   906  				// Replace the uncompressed size of the file. As a shortcut, we just
   907  				// search-and-replace the byte sequence. It should occur twice because
   908  				// the 32- and 64-byte sizes are stored separately. All multi-byte
   909  				// values are little-endian.
   910  				zipData := zipBuf.Bytes()
   911  				realSizeData := []byte{0xAD, 0x0B}
   912  				fakeSizeData := []byte{0xAC, 0x00}
   913  				s := zipData
   914  				n := 0
   915  				for {
   916  					if i := bytes.Index(s, realSizeData); i < 0 {
   917  						break
   918  					} else {
   919  						s = s[i:]
   920  					}
   921  					copy(s[:len(fakeSizeData)], fakeSizeData)
   922  					n++
   923  				}
   924  				if n != 2 {
   925  					t.Fatalf("replaced size %d times; expected 2", n)
   926  				}
   927  
   928  				// Write the modified zip to the actual file.
   929  				if _, err := zipFile.Write(zipData); err != nil {
   930  					t.Fatal(err)
   931  				}
   932  			},
   933  			wantErr: "not a valid zip file",
   934  		},
   935  	} {
   936  		test := test
   937  		t.Run(test.desc, func(t *testing.T) {
   938  			t.Parallel()
   939  
   940  			tmpZipFile := tempFile(t, "tmp.zip")
   941  			test.writeZip(t, tmpZipFile)
   942  			if err := tmpZipFile.Close(); err != nil {
   943  				t.Fatal(err)
   944  			}
   945  
   946  			tmpDir := t.TempDir()
   947  
   948  			if err := modzip.Unzip(tmpDir, test.m, tmpZipFile.Name()); err == nil && test.wantErr != "" {
   949  				t.Fatalf("unexpected success; want error containing %q", test.wantErr)
   950  			} else if err != nil && test.wantErr == "" {
   951  				t.Fatalf("got error %q; want success", err)
   952  			} else if err != nil && !strings.Contains(err.Error(), test.wantErr) {
   953  				t.Fatalf("got error %q; want error containing %q", err, test.wantErr)
   954  			}
   955  		})
   956  	}
   957  }
   958  
   959  func mustWriteFile(name string, content string) {
   960  	if err := os.MkdirAll(filepath.Dir(name), 0o777); err != nil {
   961  		panic(err)
   962  	}
   963  	if err := os.WriteFile(name, []byte(content), 0o666); err != nil {
   964  		panic(err)
   965  	}
   966  }
   967  
   968  func tempFile(t *testing.T, name string) *os.File {
   969  	f, err := os.Create(filepath.Join(t.TempDir(), name))
   970  	if err != nil {
   971  		t.Fatal(err)
   972  	}
   973  	t.Cleanup(func() { f.Close() })
   974  	return f
   975  }
   976  

View as plain text