1
2
3
4
5 package securejoin
6
7 import (
8 "errors"
9 "io/ioutil"
10 "os"
11 "path/filepath"
12 "runtime"
13 "syscall"
14 "testing"
15 )
16
17
18
19
// symlink creates newname as a symbolic link whose target is oldname, and
// fails the test immediately if the link cannot be created. It is marked as a
// test helper so a failure is reported at the caller's line, not this one.
func symlink(t *testing.T, oldname, newname string) {
	t.Helper()
	if err := os.Symlink(oldname, newname); err != nil {
		t.Fatal(err)
	}
}
25
// input is one table-driven SecureJoin test case. Test tables build it
// positionally, so the field order (root, unsafe, expected) is significant.
type input struct {
	root     string // trusted root directory passed to SecureJoin
	unsafe   string // untrusted path to be scoped beneath root
	expected string // path SecureJoin(root, unsafe) should return
}
30
31
32 func TestSymlink(t *testing.T) {
33 dir, err := ioutil.TempDir("", "TestSymlink")
34 if err != nil {
35 t.Fatal(err)
36 }
37 dir, err = filepath.EvalSymlinks(dir)
38 if err != nil {
39 t.Fatal(err)
40 }
41 defer os.RemoveAll(dir)
42
43 symlink(t, "somepath", filepath.Join(dir, "etc"))
44 symlink(t, "../../../../../../../../../../../../../etc", filepath.Join(dir, "etclink"))
45 symlink(t, "/../../../../../../../../../../../../../etc/passwd", filepath.Join(dir, "passwd"))
46
47 rootOrVol := string(filepath.Separator)
48 if vol := filepath.VolumeName(dir); vol != "" {
49 rootOrVol = vol + rootOrVol
50 }
51
52 tc := []input{
53
54 {rootOrVol, filepath.Join(dir, "passwd"), filepath.Join(rootOrVol, "etc", "passwd")},
55 {rootOrVol, filepath.Join(dir, "etclink"), filepath.Join(rootOrVol, "etc")},
56
57 {rootOrVol, filepath.Join(dir, "etc"), filepath.Join(dir, "somepath")},
58
59 {dir, "passwd", filepath.Join(dir, "somepath", "passwd")},
60 {dir, "etclink", filepath.Join(dir, "somepath")},
61 {dir, "etc", filepath.Join(dir, "somepath")},
62 {dir, "etc/test", filepath.Join(dir, "somepath", "test")},
63 {dir, "etc/test/..", filepath.Join(dir, "somepath")},
64 }
65
66 for _, test := range tc {
67 got, err := SecureJoin(test.root, test.unsafe)
68 if err != nil {
69 t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err)
70 continue
71 }
72
73
74
75 if test.root == "/" {
76 if expected, err := filepath.EvalSymlinks(test.expected); err == nil {
77 test.expected = expected
78 }
79 }
80 if got != test.expected {
81 t.Errorf("securejoin(%q, %q): expected %q, got %q", test.root, test.unsafe, test.expected, got)
82 continue
83 }
84 }
85 }
86
87
88 func TestNoSymlink(t *testing.T) {
89 dir, err := ioutil.TempDir("", "TestNoSymlink")
90 if err != nil {
91 t.Fatal(err)
92 }
93 dir, err = filepath.EvalSymlinks(dir)
94 if err != nil {
95 t.Fatal(err)
96 }
97 defer os.RemoveAll(dir)
98
99 tc := []input{
100 {dir, "somepath", filepath.Join(dir, "somepath")},
101 {dir, "even/more/path", filepath.Join(dir, "even", "more", "path")},
102 {dir, "/this/is/a/path", filepath.Join(dir, "this", "is", "a", "path")},
103 {dir, "also/a/../path/././/with/some/./.././junk", filepath.Join(dir, "also", "path", "with", "junk")},
104 {dir, "yetanother/../path/././/with/some/./.././junk../../../../../../../../../../../../etc/passwd", filepath.Join(dir, "etc", "passwd")},
105 {dir, "/../../../../../../../../../../../../../../../../etc/passwd", filepath.Join(dir, "etc", "passwd")},
106 {dir, "../../../../../../../../../../../../../../../../somedir", filepath.Join(dir, "somedir")},
107 {dir, "../../../../../../../../../../../../../../../../", filepath.Join(dir)},
108 {dir, "./../../.././././../../../../../../../../../../../../../../../../etc passwd", filepath.Join(dir, "etc passwd")},
109 }
110
111 if runtime.GOOS == "windows" {
112 tc = append(tc, []input{
113 {dir, "d:\\etc\\test", filepath.Join(dir, "etc", "test")},
114 }...)
115 }
116
117 for _, test := range tc {
118 got, err := SecureJoin(test.root, test.unsafe)
119 if err != nil {
120 t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err)
121 }
122 if got != test.expected {
123 t.Errorf("securejoin(%q, %q): expected %q, got %q", test.root, test.unsafe, test.expected, got)
124 }
125 }
126 }
127
128
129 func TestNonLexical(t *testing.T) {
130 dir, err := ioutil.TempDir("", "TestNonLexical")
131 if err != nil {
132 t.Fatal(err)
133 }
134 dir, err = filepath.EvalSymlinks(dir)
135 if err != nil {
136 t.Fatal(err)
137 }
138 defer os.RemoveAll(dir)
139
140 os.MkdirAll(filepath.Join(dir, "subdir"), 0755)
141 os.MkdirAll(filepath.Join(dir, "cousinparent", "cousin"), 0755)
142 symlink(t, "../cousinparent/cousin", filepath.Join(dir, "subdir", "link"))
143 symlink(t, "/../cousinparent/cousin", filepath.Join(dir, "subdir", "link2"))
144 symlink(t, "/../../../../../../../../../../../../../../../../cousinparent/cousin", filepath.Join(dir, "subdir", "link3"))
145
146 for _, test := range []input{
147 {dir, "subdir", filepath.Join(dir, "subdir")},
148 {dir, "subdir/link/test", filepath.Join(dir, "cousinparent", "cousin", "test")},
149 {dir, "subdir/link2/test", filepath.Join(dir, "cousinparent", "cousin", "test")},
150 {dir, "subdir/link3/test", filepath.Join(dir, "cousinparent", "cousin", "test")},
151 {dir, "subdir/../test", filepath.Join(dir, "test")},
152
153 {dir, "subdir/link/../test", filepath.Join(dir, "cousinparent", "test")},
154 {dir, "subdir/link2/../test", filepath.Join(dir, "cousinparent", "test")},
155 {dir, "subdir/link3/../test", filepath.Join(dir, "cousinparent", "test")},
156 } {
157 got, err := SecureJoin(test.root, test.unsafe)
158 if err != nil {
159 t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err)
160 continue
161 }
162 if got != test.expected {
163 t.Errorf("securejoin(%q, %q): expected %q, got %q", test.root, test.unsafe, test.expected, got)
164 continue
165 }
166 }
167 }
168
169
170 func TestSymlinkLoop(t *testing.T) {
171 dir, err := ioutil.TempDir("", "TestSymlinkLoop")
172 if err != nil {
173 t.Fatal(err)
174 }
175 dir, err = filepath.EvalSymlinks(dir)
176 if err != nil {
177 t.Fatal(err)
178 }
179 defer os.RemoveAll(dir)
180
181 os.MkdirAll(filepath.Join(dir, "subdir"), 0755)
182 symlink(t, "../../../../../../../../../../../../../../../../path", filepath.Join(dir, "subdir", "link"))
183 symlink(t, "/subdir/link", filepath.Join(dir, "path"))
184 symlink(t, "/../../../../../../../../../../../../../../../../self", filepath.Join(dir, "self"))
185
186 for _, test := range []struct {
187 root, unsafe string
188 }{
189 {dir, "subdir/link"},
190 {dir, "path"},
191 {dir, "../../path"},
192 {dir, "subdir/link/../.."},
193 {dir, "../../../../../../../../../../../../../../../../subdir/link/../../../../../../../../../../../../../../../.."},
194 {dir, "self"},
195 {dir, "self/.."},
196 {dir, "/../../../../../../../../../../../../../../../../self/.."},
197 {dir, "/self/././.."},
198 } {
199 got, err := SecureJoin(test.root, test.unsafe)
200 if !errors.Is(err, syscall.ELOOP) {
201 t.Errorf("securejoin(%q, %q): expected ELOOP, got %q & %v", test.root, test.unsafe, got, err)
202 continue
203 }
204 }
205 }
206
207
208 func TestEnotdir(t *testing.T) {
209 dir, err := ioutil.TempDir("", "TestEnotdir")
210 if err != nil {
211 t.Fatal(err)
212 }
213 dir, err = filepath.EvalSymlinks(dir)
214 if err != nil {
215 t.Fatal(err)
216 }
217 defer os.RemoveAll(dir)
218
219 os.MkdirAll(filepath.Join(dir, "subdir"), 0755)
220 ioutil.WriteFile(filepath.Join(dir, "notdir"), []byte("I am not a directory!"), 0755)
221 symlink(t, "/../../../notdir/somechild", filepath.Join(dir, "subdir", "link"))
222
223 for _, test := range []struct {
224 root, unsafe string
225 }{
226 {dir, "subdir/link"},
227 {dir, "notdir"},
228 {dir, "notdir/child"},
229 } {
230 _, err := SecureJoin(test.root, test.unsafe)
231 if err != nil {
232 t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err)
233 continue
234 }
235 }
236 }
237
238
239 func TestIsNotExist(t *testing.T) {
240 for _, test := range []struct {
241 err error
242 expected bool
243 }{
244 {&os.PathError{Op: "test1", Err: syscall.ENOENT}, true},
245 {&os.LinkError{Op: "test1", Err: syscall.ENOENT}, true},
246 {&os.SyscallError{Syscall: "test1", Err: syscall.ENOENT}, true},
247 {&os.PathError{Op: "test2", Err: syscall.ENOTDIR}, true},
248 {&os.LinkError{Op: "test2", Err: syscall.ENOTDIR}, true},
249 {&os.SyscallError{Syscall: "test2", Err: syscall.ENOTDIR}, true},
250 {&os.PathError{Op: "test3", Err: syscall.EACCES}, false},
251 {&os.LinkError{Op: "test3", Err: syscall.EACCES}, false},
252 {&os.SyscallError{Syscall: "test3", Err: syscall.EACCES}, false},
253 {errors.New("not a proper error"), false},
254 } {
255 got := IsNotExist(test.err)
256 if got != test.expected {
257 t.Errorf("IsNotExist(%#v): expected %v, got %v", test.err, test.expected, got)
258 }
259 }
260 }
261
// mockVFS is a VFS test double. Its Lstat and Readlink methods forward
// straight to the function fields, which lets a test inject failures or count
// calls.
type mockVFS struct {
	lstat    func(path string) (os.FileInfo, error)
	readlink func(path string) (string, error)
}

// Lstat forwards to the lstat field.
func (v mockVFS) Lstat(path string) (os.FileInfo, error) {
	impl := v.lstat
	return impl(path)
}

// Readlink forwards to the readlink field.
func (v mockVFS) Readlink(path string) (string, error) {
	impl := v.readlink
	return impl(path)
}
269
270
271 func TestSecureJoinVFS(t *testing.T) {
272 dir, err := ioutil.TempDir("", "TestNonLexical")
273 if err != nil {
274 t.Fatal(err)
275 }
276 dir, err = filepath.EvalSymlinks(dir)
277 if err != nil {
278 t.Fatal(err)
279 }
280 defer os.RemoveAll(dir)
281
282 os.MkdirAll(filepath.Join(dir, "subdir"), 0755)
283 os.MkdirAll(filepath.Join(dir, "cousinparent", "cousin"), 0755)
284 symlink(t, "../cousinparent/cousin", filepath.Join(dir, "subdir", "link"))
285 symlink(t, "/../cousinparent/cousin", filepath.Join(dir, "subdir", "link2"))
286 symlink(t, "/../../../../../../../../../../../../../../../../cousinparent/cousin", filepath.Join(dir, "subdir", "link3"))
287
288 for _, test := range []input{
289 {dir, "subdir", filepath.Join(dir, "subdir")},
290 {dir, "subdir/link/test", filepath.Join(dir, "cousinparent", "cousin", "test")},
291 {dir, "subdir/link2/test", filepath.Join(dir, "cousinparent", "cousin", "test")},
292 {dir, "subdir/link3/test", filepath.Join(dir, "cousinparent", "cousin", "test")},
293 {dir, "subdir/../test", filepath.Join(dir, "test")},
294
295 {dir, "subdir/link/../test", filepath.Join(dir, "cousinparent", "test")},
296 {dir, "subdir/link2/../test", filepath.Join(dir, "cousinparent", "test")},
297 {dir, "subdir/link3/../test", filepath.Join(dir, "cousinparent", "test")},
298 } {
299 var nLstat, nReadlink int
300 mock := mockVFS{
301 lstat: func(path string) (os.FileInfo, error) { nLstat++; return os.Lstat(path) },
302 readlink: func(path string) (string, error) { nReadlink++; return os.Readlink(path) },
303 }
304
305 got, err := SecureJoinVFS(test.root, test.unsafe, mock)
306 if err != nil {
307 t.Errorf("securejoin(%q, %q): unexpected error: %v", test.root, test.unsafe, err)
308 continue
309 }
310 if got != test.expected {
311 t.Errorf("securejoin(%q, %q): expected %q, got %q", test.root, test.unsafe, test.expected, got)
312 continue
313 }
314 if nLstat == 0 && nReadlink == 0 {
315 t.Errorf("securejoin(%q, %q): expected to use either lstat or readlink, neither were used", test.root, test.unsafe)
316 }
317 }
318 }
319
320
321
322 func TestSecureJoinVFSErrors(t *testing.T) {
323 var (
324 lstatErr = errors.New("lstat error")
325 readlinkErr = errors.New("readlink err")
326 )
327
328
329 dir, err := ioutil.TempDir("", "TestSecureJoinVFSErrors")
330 if err != nil {
331 t.Fatal(err)
332 }
333 dir, err = filepath.EvalSymlinks(dir)
334 if err != nil {
335 t.Fatal(err)
336 }
337 defer os.RemoveAll(dir)
338
339
340 symlink(t, "../../../../../../../../../../../../../../../../path", filepath.Join(dir, "link"))
341
342
343 lstatFailFn := func(path string) (os.FileInfo, error) { return nil, lstatErr }
344 readlinkFailFn := func(path string) (string, error) { return "", readlinkErr }
345
346
347 for idx, test := range []struct {
348 vfs VFS
349 expected []error
350 }{
351 {
352 expected: []error{nil},
353 vfs: mockVFS{
354 lstat: os.Lstat,
355 readlink: os.Readlink,
356 },
357 },
358 {
359 expected: []error{lstatErr},
360 vfs: mockVFS{
361 lstat: lstatFailFn,
362 readlink: os.Readlink,
363 },
364 },
365 {
366 expected: []error{readlinkErr},
367 vfs: mockVFS{
368 lstat: os.Lstat,
369 readlink: readlinkFailFn,
370 },
371 },
372 {
373 expected: []error{lstatErr, readlinkErr},
374 vfs: mockVFS{
375 lstat: lstatFailFn,
376 readlink: readlinkFailFn,
377 },
378 },
379 } {
380 _, err := SecureJoinVFS(dir, "link", test.vfs)
381
382 success := false
383 for _, exp := range test.expected {
384 if err == exp {
385 success = true
386 }
387 }
388 if !success {
389 t.Errorf("SecureJoinVFS.mock%d: expected to get lstatError, got %v", idx, err)
390 }
391 }
392 }
393
View as plain text