...

Source file src/github.com/cyphar/filepath-securejoin/join_test.go

Documentation: github.com/cyphar/filepath-securejoin

     1  // Copyright (C) 2017 SUSE LLC. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package securejoin
     6  
     7  import (
     8  	"errors"
     9  	"io/ioutil"
    10  	"os"
    11  	"path/filepath"
    12  	"runtime"
    13  	"syscall"
    14  	"testing"
    15  )
    16  
    17  // TODO: These tests won't work on plan9 because it doesn't have symlinks, and
    18  //       also we use '/' here explicitly which probably won't work on Windows.
    19  
    20  func symlink(t *testing.T, oldname, newname string) {
    21  	if err := os.Symlink(oldname, newname); err != nil {
    22  		t.Fatal(err)
    23  	}
    24  }
    25  
    26  type input struct {
    27  	root, unsafe string
    28  	expected     string
    29  }
    30  
    31  // Test basic handling of symlink expansion.
    32  func TestSymlink(t *testing.T) {
    33  	dir, err := ioutil.TempDir("", "TestSymlink")
    34  	if err != nil {
    35  		t.Fatal(err)
    36  	}
    37  	dir, err = filepath.EvalSymlinks(dir)
    38  	if err != nil {
    39  		t.Fatal(err)
    40  	}
    41  	defer os.RemoveAll(dir)
    42  
    43  	symlink(t, "somepath", filepath.Join(dir, "etc"))
    44  	symlink(t, "../../../../../../../../../../../../../etc", filepath.Join(dir, "etclink"))
    45  	symlink(t, "/../../../../../../../../../../../../../etc/passwd", filepath.Join(dir, "passwd"))
    46  
    47  	rootOrVol := string(filepath.Separator)
    48  	if vol := filepath.VolumeName(dir); vol != "" {
    49  		rootOrVol = vol + rootOrVol
    50  	}
    51  
    52  	tc := []input{
    53  		// Make sure that expansion with a root of '/' proceeds in the expected fashion.
    54  		{rootOrVol, filepath.Join(dir, "passwd"), filepath.Join(rootOrVol, "etc", "passwd")},
    55  		{rootOrVol, filepath.Join(dir, "etclink"), filepath.Join(rootOrVol, "etc")},
    56  
    57  		{rootOrVol, filepath.Join(dir, "etc"), filepath.Join(dir, "somepath")},
    58  		// Now test scoped expansion.
    59  		{dir, "passwd", filepath.Join(dir, "somepath", "passwd")},
    60  		{dir, "etclink", filepath.Join(dir, "somepath")},
    61  		{dir, "etc", filepath.Join(dir, "somepath")},
    62  		{dir, "etc/test", filepath.Join(dir, "somepath", "test")},
    63  		{dir, "etc/test/..", filepath.Join(dir, "somepath")},
    64  	}
    65  
    66  	for _, test := range tc {
    67  		got, err := SecureJoin(test.root, test.unsafe)
    68  		if err != nil {
    69  			t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err)
    70  			continue
    71  		}
    72  		// This is only for OS X, where /etc is a symlink to /private/etc. In
    73  		// principle, SecureJoin(/, pth) is the same as EvalSymlinks(pth) in
    74  		// the case where the path exists.
    75  		if test.root == "/" {
    76  			if expected, err := filepath.EvalSymlinks(test.expected); err == nil {
    77  				test.expected = expected
    78  			}
    79  		}
    80  		if got != test.expected {
    81  			t.Errorf("securejoin(%q, %q): expected %q, got %q", test.root, test.unsafe, test.expected, got)
    82  			continue
    83  		}
    84  	}
    85  }
    86  
    87  // In a path without symlinks, SecureJoin is equivalent to Clean+Join.
    88  func TestNoSymlink(t *testing.T) {
    89  	dir, err := ioutil.TempDir("", "TestNoSymlink")
    90  	if err != nil {
    91  		t.Fatal(err)
    92  	}
    93  	dir, err = filepath.EvalSymlinks(dir)
    94  	if err != nil {
    95  		t.Fatal(err)
    96  	}
    97  	defer os.RemoveAll(dir)
    98  
    99  	tc := []input{
   100  		{dir, "somepath", filepath.Join(dir, "somepath")},
   101  		{dir, "even/more/path", filepath.Join(dir, "even", "more", "path")},
   102  		{dir, "/this/is/a/path", filepath.Join(dir, "this", "is", "a", "path")},
   103  		{dir, "also/a/../path/././/with/some/./.././junk", filepath.Join(dir, "also", "path", "with", "junk")},
   104  		{dir, "yetanother/../path/././/with/some/./.././junk../../../../../../../../../../../../etc/passwd", filepath.Join(dir, "etc", "passwd")},
   105  		{dir, "/../../../../../../../../../../../../../../../../etc/passwd", filepath.Join(dir, "etc", "passwd")},
   106  		{dir, "../../../../../../../../../../../../../../../../somedir", filepath.Join(dir, "somedir")},
   107  		{dir, "../../../../../../../../../../../../../../../../", filepath.Join(dir)},
   108  		{dir, "./../../.././././../../../../../../../../../../../../../../../../etc passwd", filepath.Join(dir, "etc passwd")},
   109  	}
   110  
   111  	if runtime.GOOS == "windows" {
   112  		tc = append(tc, []input{
   113  			{dir, "d:\\etc\\test", filepath.Join(dir, "etc", "test")},
   114  		}...)
   115  	}
   116  
   117  	for _, test := range tc {
   118  		got, err := SecureJoin(test.root, test.unsafe)
   119  		if err != nil {
   120  			t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err)
   121  		}
   122  		if got != test.expected {
   123  			t.Errorf("securejoin(%q, %q): expected %q, got %q", test.root, test.unsafe, test.expected, got)
   124  		}
   125  	}
   126  }
   127  
   128  // Make sure that .. is **not** expanded lexically.
   129  func TestNonLexical(t *testing.T) {
   130  	dir, err := ioutil.TempDir("", "TestNonLexical")
   131  	if err != nil {
   132  		t.Fatal(err)
   133  	}
   134  	dir, err = filepath.EvalSymlinks(dir)
   135  	if err != nil {
   136  		t.Fatal(err)
   137  	}
   138  	defer os.RemoveAll(dir)
   139  
   140  	os.MkdirAll(filepath.Join(dir, "subdir"), 0755)
   141  	os.MkdirAll(filepath.Join(dir, "cousinparent", "cousin"), 0755)
   142  	symlink(t, "../cousinparent/cousin", filepath.Join(dir, "subdir", "link"))
   143  	symlink(t, "/../cousinparent/cousin", filepath.Join(dir, "subdir", "link2"))
   144  	symlink(t, "/../../../../../../../../../../../../../../../../cousinparent/cousin", filepath.Join(dir, "subdir", "link3"))
   145  
   146  	for _, test := range []input{
   147  		{dir, "subdir", filepath.Join(dir, "subdir")},
   148  		{dir, "subdir/link/test", filepath.Join(dir, "cousinparent", "cousin", "test")},
   149  		{dir, "subdir/link2/test", filepath.Join(dir, "cousinparent", "cousin", "test")},
   150  		{dir, "subdir/link3/test", filepath.Join(dir, "cousinparent", "cousin", "test")},
   151  		{dir, "subdir/../test", filepath.Join(dir, "test")},
   152  		// This is the divergence from a simple filepath.Clean implementation.
   153  		{dir, "subdir/link/../test", filepath.Join(dir, "cousinparent", "test")},
   154  		{dir, "subdir/link2/../test", filepath.Join(dir, "cousinparent", "test")},
   155  		{dir, "subdir/link3/../test", filepath.Join(dir, "cousinparent", "test")},
   156  	} {
   157  		got, err := SecureJoin(test.root, test.unsafe)
   158  		if err != nil {
   159  			t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err)
   160  			continue
   161  		}
   162  		if got != test.expected {
   163  			t.Errorf("securejoin(%q, %q): expected %q, got %q", test.root, test.unsafe, test.expected, got)
   164  			continue
   165  		}
   166  	}
   167  }
   168  
   169  // Make sure that symlink loops result in errors.
   170  func TestSymlinkLoop(t *testing.T) {
   171  	dir, err := ioutil.TempDir("", "TestSymlinkLoop")
   172  	if err != nil {
   173  		t.Fatal(err)
   174  	}
   175  	dir, err = filepath.EvalSymlinks(dir)
   176  	if err != nil {
   177  		t.Fatal(err)
   178  	}
   179  	defer os.RemoveAll(dir)
   180  
   181  	os.MkdirAll(filepath.Join(dir, "subdir"), 0755)
   182  	symlink(t, "../../../../../../../../../../../../../../../../path", filepath.Join(dir, "subdir", "link"))
   183  	symlink(t, "/subdir/link", filepath.Join(dir, "path"))
   184  	symlink(t, "/../../../../../../../../../../../../../../../../self", filepath.Join(dir, "self"))
   185  
   186  	for _, test := range []struct {
   187  		root, unsafe string
   188  	}{
   189  		{dir, "subdir/link"},
   190  		{dir, "path"},
   191  		{dir, "../../path"},
   192  		{dir, "subdir/link/../.."},
   193  		{dir, "../../../../../../../../../../../../../../../../subdir/link/../../../../../../../../../../../../../../../.."},
   194  		{dir, "self"},
   195  		{dir, "self/.."},
   196  		{dir, "/../../../../../../../../../../../../../../../../self/.."},
   197  		{dir, "/self/././.."},
   198  	} {
   199  		got, err := SecureJoin(test.root, test.unsafe)
   200  		if !errors.Is(err, syscall.ELOOP) {
   201  			t.Errorf("securejoin(%q, %q): expected ELOOP, got %q & %v", test.root, test.unsafe, got, err)
   202  			continue
   203  		}
   204  	}
   205  }
   206  
   207  // Make sure that ENOTDIR is correctly handled.
   208  func TestEnotdir(t *testing.T) {
   209  	dir, err := ioutil.TempDir("", "TestEnotdir")
   210  	if err != nil {
   211  		t.Fatal(err)
   212  	}
   213  	dir, err = filepath.EvalSymlinks(dir)
   214  	if err != nil {
   215  		t.Fatal(err)
   216  	}
   217  	defer os.RemoveAll(dir)
   218  
   219  	os.MkdirAll(filepath.Join(dir, "subdir"), 0755)
   220  	ioutil.WriteFile(filepath.Join(dir, "notdir"), []byte("I am not a directory!"), 0755)
   221  	symlink(t, "/../../../notdir/somechild", filepath.Join(dir, "subdir", "link"))
   222  
   223  	for _, test := range []struct {
   224  		root, unsafe string
   225  	}{
   226  		{dir, "subdir/link"},
   227  		{dir, "notdir"},
   228  		{dir, "notdir/child"},
   229  	} {
   230  		_, err := SecureJoin(test.root, test.unsafe)
   231  		if err != nil {
   232  			t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err)
   233  			continue
   234  		}
   235  	}
   236  }
   237  
   238  // Some silly tests to make sure that all error types are correctly handled.
   239  func TestIsNotExist(t *testing.T) {
   240  	for _, test := range []struct {
   241  		err      error
   242  		expected bool
   243  	}{
   244  		{&os.PathError{Op: "test1", Err: syscall.ENOENT}, true},
   245  		{&os.LinkError{Op: "test1", Err: syscall.ENOENT}, true},
   246  		{&os.SyscallError{Syscall: "test1", Err: syscall.ENOENT}, true},
   247  		{&os.PathError{Op: "test2", Err: syscall.ENOTDIR}, true},
   248  		{&os.LinkError{Op: "test2", Err: syscall.ENOTDIR}, true},
   249  		{&os.SyscallError{Syscall: "test2", Err: syscall.ENOTDIR}, true},
   250  		{&os.PathError{Op: "test3", Err: syscall.EACCES}, false},
   251  		{&os.LinkError{Op: "test3", Err: syscall.EACCES}, false},
   252  		{&os.SyscallError{Syscall: "test3", Err: syscall.EACCES}, false},
   253  		{errors.New("not a proper error"), false},
   254  	} {
   255  		got := IsNotExist(test.err)
   256  		if got != test.expected {
   257  			t.Errorf("IsNotExist(%#v): expected %v, got %v", test.err, test.expected, got)
   258  		}
   259  	}
   260  }
   261  
   262  type mockVFS struct {
   263  	lstat    func(path string) (os.FileInfo, error)
   264  	readlink func(path string) (string, error)
   265  }
   266  
   267  func (m mockVFS) Lstat(path string) (os.FileInfo, error) { return m.lstat(path) }
   268  func (m mockVFS) Readlink(path string) (string, error)   { return m.readlink(path) }
   269  
   270  // Make sure that SecureJoinVFS actually does use the given VFS interface.
   271  func TestSecureJoinVFS(t *testing.T) {
   272  	dir, err := ioutil.TempDir("", "TestNonLexical")
   273  	if err != nil {
   274  		t.Fatal(err)
   275  	}
   276  	dir, err = filepath.EvalSymlinks(dir)
   277  	if err != nil {
   278  		t.Fatal(err)
   279  	}
   280  	defer os.RemoveAll(dir)
   281  
   282  	os.MkdirAll(filepath.Join(dir, "subdir"), 0755)
   283  	os.MkdirAll(filepath.Join(dir, "cousinparent", "cousin"), 0755)
   284  	symlink(t, "../cousinparent/cousin", filepath.Join(dir, "subdir", "link"))
   285  	symlink(t, "/../cousinparent/cousin", filepath.Join(dir, "subdir", "link2"))
   286  	symlink(t, "/../../../../../../../../../../../../../../../../cousinparent/cousin", filepath.Join(dir, "subdir", "link3"))
   287  
   288  	for _, test := range []input{
   289  		{dir, "subdir", filepath.Join(dir, "subdir")},
   290  		{dir, "subdir/link/test", filepath.Join(dir, "cousinparent", "cousin", "test")},
   291  		{dir, "subdir/link2/test", filepath.Join(dir, "cousinparent", "cousin", "test")},
   292  		{dir, "subdir/link3/test", filepath.Join(dir, "cousinparent", "cousin", "test")},
   293  		{dir, "subdir/../test", filepath.Join(dir, "test")},
   294  		// This is the divergence from a simple filepath.Clean implementation.
   295  		{dir, "subdir/link/../test", filepath.Join(dir, "cousinparent", "test")},
   296  		{dir, "subdir/link2/../test", filepath.Join(dir, "cousinparent", "test")},
   297  		{dir, "subdir/link3/../test", filepath.Join(dir, "cousinparent", "test")},
   298  	} {
   299  		var nLstat, nReadlink int
   300  		mock := mockVFS{
   301  			lstat:    func(path string) (os.FileInfo, error) { nLstat++; return os.Lstat(path) },
   302  			readlink: func(path string) (string, error) { nReadlink++; return os.Readlink(path) },
   303  		}
   304  
   305  		got, err := SecureJoinVFS(test.root, test.unsafe, mock)
   306  		if err != nil {
   307  			t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err)
   308  			continue
   309  		}
   310  		if got != test.expected {
   311  			t.Errorf("securejoin(%q, %q): expected %q, got %q", test.root, test.unsafe, test.expected, got)
   312  			continue
   313  		}
   314  		if nLstat == 0 && nReadlink == 0 {
   315  			t.Errorf("securejoin(%q, %q): expected to use either lstat or readlink, neither were used", test.root, test.unsafe)
   316  		}
   317  	}
   318  }
   319  
   320  // Make sure that SecureJoinVFS actually does use the given VFS interface, and
   321  // that errors are correctly propagated.
   322  func TestSecureJoinVFSErrors(t *testing.T) {
   323  	var (
   324  		lstatErr    = errors.New("lstat error")
   325  		readlinkErr = errors.New("readlink err")
   326  	)
   327  
   328  	// Set up directory.
   329  	dir, err := ioutil.TempDir("", "TestSecureJoinVFSErrors")
   330  	if err != nil {
   331  		t.Fatal(err)
   332  	}
   333  	dir, err = filepath.EvalSymlinks(dir)
   334  	if err != nil {
   335  		t.Fatal(err)
   336  	}
   337  	defer os.RemoveAll(dir)
   338  
   339  	// Make a link.
   340  	symlink(t, "../../../../../../../../../../../../../../../../path", filepath.Join(dir, "link"))
   341  
   342  	// Define some fake mock functions.
   343  	lstatFailFn := func(path string) (os.FileInfo, error) { return nil, lstatErr }
   344  	readlinkFailFn := func(path string) (string, error) { return "", readlinkErr }
   345  
   346  	// Make sure that the set of {lstat, readlink} failures do propagate.
   347  	for idx, test := range []struct {
   348  		vfs      VFS
   349  		expected []error
   350  	}{
   351  		{
   352  			expected: []error{nil},
   353  			vfs: mockVFS{
   354  				lstat:    os.Lstat,
   355  				readlink: os.Readlink,
   356  			},
   357  		},
   358  		{
   359  			expected: []error{lstatErr},
   360  			vfs: mockVFS{
   361  				lstat:    lstatFailFn,
   362  				readlink: os.Readlink,
   363  			},
   364  		},
   365  		{
   366  			expected: []error{readlinkErr},
   367  			vfs: mockVFS{
   368  				lstat:    os.Lstat,
   369  				readlink: readlinkFailFn,
   370  			},
   371  		},
   372  		{
   373  			expected: []error{lstatErr, readlinkErr},
   374  			vfs: mockVFS{
   375  				lstat:    lstatFailFn,
   376  				readlink: readlinkFailFn,
   377  			},
   378  		},
   379  	} {
   380  		_, err := SecureJoinVFS(dir, "link", test.vfs)
   381  
   382  		success := false
   383  		for _, exp := range test.expected {
   384  			if err == exp {
   385  				success = true
   386  			}
   387  		}
   388  		if !success {
   389  			t.Errorf("SecureJoinVFS.mock%d: expected to get lstatError, got %v", idx, err)
   390  		}
   391  	}
   392  }
   393  

View as plain text