1
2
3
4 package pmem
5
6 import (
7 "context"
8 "fmt"
9 "os"
10 "testing"
11
12 "github.com/pkg/errors"
13 "golang.org/x/sys/unix"
14
15 "github.com/Microsoft/hcsshim/internal/protocol/guestresource"
16 )
17
18 func clearTestDependencies() {
19 osMkdirAll = nil
20 osRemoveAll = nil
21 unixMount = nil
22 createZeroSectorLinearTarget = nil
23 createVerityTarget = nil
24 removeDevice = nil
25 mountInternal = mount
26 }
27
28 func Test_Mount_Mkdir_Fails_Error(t *testing.T) {
29 clearTestDependencies()
30
31 expectedErr := errors.New("mkdir : no such file or directory")
32 osMkdirAll = func(path string, perm os.FileMode) error {
33 return expectedErr
34 }
35 err := Mount(context.Background(), 0, "", nil, nil)
36 if errors.Cause(err) != expectedErr {
37 t.Fatalf("expected err: %v, got: %v", expectedErr, err)
38 }
39 }
40
41 func Test_Mount_Mkdir_ExpectedPath(t *testing.T) {
42 clearTestDependencies()
43
44
45
46
47 target := "/fake/path"
48 osMkdirAll = func(path string, perm os.FileMode) error {
49 if path != target {
50 t.Errorf("expected path: %v, got: %v", target, path)
51 return errors.New("unexpected path")
52 }
53 return nil
54 }
55 unixMount = func(source string, target string, fstype string, flags uintptr, data string) error {
56
57 return nil
58 }
59 err := Mount(context.Background(), 0, target, nil, nil)
60 if err != nil {
61 t.Fatalf("expected nil error got: %v", err)
62 }
63 }
64
65 func Test_Mount_Mkdir_ExpectedPerm(t *testing.T) {
66 clearTestDependencies()
67
68
69
70
71 target := "/fake/path"
72 osMkdirAll = func(path string, perm os.FileMode) error {
73 if perm != os.FileMode(0700) {
74 t.Errorf("expected perm: %v, got: %v", os.FileMode(0700), perm)
75 return errors.New("unexpected perm")
76 }
77 return nil
78 }
79 unixMount = func(source string, target string, fstype string, flags uintptr, data string) error {
80
81 return nil
82 }
83 err := Mount(context.Background(), 0, target, nil, nil)
84 if err != nil {
85 t.Fatalf("expected nil error got: %v", err)
86 }
87 }
88
89 func Test_Mount_Calls_RemoveAll_OnMountFailure(t *testing.T) {
90 clearTestDependencies()
91
92 osMkdirAll = func(path string, perm os.FileMode) error {
93 return nil
94 }
95 target := "/fake/path"
96 removeAllCalled := false
97 osRemoveAll = func(path string) error {
98 removeAllCalled = true
99 if path != target {
100 t.Errorf("expected path: %v, got: %v", target, path)
101 return errors.New("unexpected path")
102 }
103 return nil
104 }
105 expectedErr := errors.New("unexpected mount failure")
106 unixMount = func(source string, target string, fstype string, flags uintptr, data string) error {
107
108 return expectedErr
109 }
110 err := Mount(context.Background(), 0, target, nil, nil)
111 if errors.Cause(err) != expectedErr {
112 t.Fatalf("expected err: %v, got: %v", expectedErr, err)
113 }
114 if !removeAllCalled {
115 t.Fatal("expected os.RemoveAll to be called on mount failure")
116 }
117 }
118
119 func Test_Mount_Valid_Source(t *testing.T) {
120 clearTestDependencies()
121
122
123
124
125 osMkdirAll = func(path string, perm os.FileMode) error {
126 return nil
127 }
128 device := uint32(20)
129 unixMount = func(source string, target string, fstype string, flags uintptr, data string) error {
130 expected := fmt.Sprintf("/dev/pmem%d", device)
131 if source != expected {
132 t.Errorf("expected source: %s, got: %s", expected, source)
133 return errors.New("unexpected source")
134 }
135 return nil
136 }
137 err := Mount(context.Background(), device, "/fake/path", nil, nil)
138 if err != nil {
139 t.Fatalf("expected nil err, got: %v", err)
140 }
141 }
142
143 func Test_Mount_Valid_Target(t *testing.T) {
144 clearTestDependencies()
145
146
147
148
149 osMkdirAll = func(path string, perm os.FileMode) error {
150 return nil
151 }
152 expectedTarget := "/fake/path"
153 unixMount = func(source string, target string, fstype string, flags uintptr, data string) error {
154 if expectedTarget != target {
155 t.Errorf("expected target: %s, got: %s", expectedTarget, target)
156 return errors.New("unexpected target")
157 }
158 return nil
159 }
160 err := Mount(context.Background(), 0, expectedTarget, nil, nil)
161 if err != nil {
162 t.Fatalf("expected nil err, got: %v", err)
163 }
164 }
165
166 func Test_Mount_Valid_FSType(t *testing.T) {
167 clearTestDependencies()
168
169
170
171
172 osMkdirAll = func(path string, perm os.FileMode) error {
173 return nil
174 }
175 unixMount = func(source string, target string, fstype string, flags uintptr, data string) error {
176 expectedFSType := "ext4"
177 if expectedFSType != fstype {
178 t.Errorf("expected fstype: %s, got: %s", expectedFSType, fstype)
179 return errors.New("unexpected fstype")
180 }
181 return nil
182 }
183 err := Mount(context.Background(), 0, "/fake/path", nil, nil)
184 if err != nil {
185 t.Fatalf("expected nil err, got: %v", err)
186 }
187 }
188
189 func Test_Mount_Valid_Flags(t *testing.T) {
190 clearTestDependencies()
191
192
193
194
195 osMkdirAll = func(path string, perm os.FileMode) error {
196 return nil
197 }
198 unixMount = func(source string, target string, fstype string, flags uintptr, data string) error {
199 expectedFlags := uintptr(unix.MS_RDONLY)
200 if expectedFlags != flags {
201 t.Errorf("expected flags: %v, got: %v", expectedFlags, flags)
202 return errors.New("unexpected flags")
203 }
204 return nil
205 }
206 err := Mount(context.Background(), 0, "/fake/path", nil, nil)
207 if err != nil {
208 t.Fatalf("expected nil err, got: %v", err)
209 }
210 }
211
212 func Test_Mount_Valid_Data(t *testing.T) {
213 clearTestDependencies()
214
215
216
217
218 osMkdirAll = func(path string, perm os.FileMode) error {
219 return nil
220 }
221 unixMount = func(source string, target string, fstype string, flags uintptr, data string) error {
222 expectedData := "noload"
223 if expectedData != data {
224 t.Errorf("expected data: %s, got: %s", expectedData, data)
225 return errors.New("unexpected data")
226 }
227 return nil
228 }
229 err := Mount(context.Background(), 0, "/fake/path", nil, nil)
230 if err != nil {
231 t.Fatalf("expected nil err, got: %v", err)
232 }
233 }
234
235
236 func Test_CreateLinearTarget_And_Mount_Called_With_Correct_Parameters(t *testing.T) {
237 clearTestDependencies()
238
239 mappingInfo := &guestresource.LCOWVPMemMappingInfo{
240 DeviceOffsetInBytes: 0,
241 DeviceSizeInBytes: 1024,
242 }
243 expectedLinearName := fmt.Sprintf(linearDeviceFmt, 0, mappingInfo.DeviceOffsetInBytes, mappingInfo.DeviceSizeInBytes)
244 expectedSource := "/dev/pmem0"
245 expectedTarget := "/foo"
246 mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedLinearName)
247 createZSLTCalled := false
248
249 osMkdirAll = func(_ string, _ os.FileMode) error {
250 return nil
251 }
252
253 mountInternal = func(_ context.Context, source, target string) error {
254 if source != mapperPath {
255 t.Errorf("expected mountInternal source %s, got %s", mapperPath, source)
256 }
257 if target != expectedTarget {
258 t.Errorf("expected mountInternal target %s, got %s", expectedTarget, source)
259 }
260 return nil
261 }
262
263 createZeroSectorLinearTarget = func(_ context.Context, source, name string, mapping *guestresource.LCOWVPMemMappingInfo) (string, error) {
264 createZSLTCalled = true
265 if source != expectedSource {
266 t.Errorf("expected createZeroSectorLinearTarget source %s, got %s", expectedSource, source)
267 }
268 if name != expectedLinearName {
269 t.Errorf("expected createZeroSectorLinearTarget name %s, got %s", expectedLinearName, name)
270 }
271 return mapperPath, nil
272 }
273
274 if err := Mount(
275 context.Background(),
276 0,
277 expectedTarget,
278 mappingInfo,
279 nil,
280 ); err != nil {
281 t.Fatalf("unexpected error during Mount: %s", err)
282 }
283 if !createZSLTCalled {
284 t.Fatalf("createZeroSectorLinearTarget not called")
285 }
286 }
287
288 func Test_CreateVerityTargetCalled_And_Mount_Called_With_Correct_Parameters(t *testing.T) {
289 clearTestDependencies()
290
291 verityInfo := &guestresource.DeviceVerityInfo{
292 RootDigest: "hash",
293 }
294 expectedVerityName := fmt.Sprintf(verityDeviceFmt, 0, verityInfo.RootDigest)
295 expectedSource := "/dev/pmem0"
296 expectedTarget := "/foo"
297 mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityName)
298 createVerityTargetCalled := false
299
300 mountInternal = func(_ context.Context, source, target string) error {
301 if source != mapperPath {
302 t.Errorf("expected mountInternal source %s, got %s", mapperPath, source)
303 }
304 if target != expectedTarget {
305 t.Errorf("expected mountInternal target %s, got %s", expectedTarget, target)
306 }
307 return nil
308 }
309 createVerityTarget = func(_ context.Context, source, name string, verity *guestresource.DeviceVerityInfo) (string, error) {
310 createVerityTargetCalled = true
311 if source != expectedSource {
312 t.Errorf("expected createVerityTarget source %s, got %s", expectedSource, source)
313 }
314 if name != expectedVerityName {
315 t.Errorf("expected createVerityTarget name %s, got %s", expectedVerityName, name)
316 }
317 return mapperPath, nil
318 }
319
320 if err := Mount(
321 context.Background(),
322 0,
323 expectedTarget,
324 nil,
325 verityInfo,
326 ); err != nil {
327 t.Fatalf("unexpected Mount failure: %s", err)
328 }
329 if !createVerityTargetCalled {
330 t.Fatal("createVerityTarget not called")
331 }
332 }
333
334 func Test_CreateLinearTarget_And_CreateVerityTargetCalled_Called_Correctly(t *testing.T) {
335 clearTestDependencies()
336
337 verityInfo := &guestresource.DeviceVerityInfo{
338 RootDigest: "hash",
339 }
340 mapping := &guestresource.LCOWVPMemMappingInfo{
341 DeviceOffsetInBytes: 0,
342 DeviceSizeInBytes: 1024,
343 }
344 expectedLinearTarget := fmt.Sprintf(linearDeviceFmt, 0, mapping.DeviceOffsetInBytes, mapping.DeviceSizeInBytes)
345 expectedVerityTarget := fmt.Sprintf(verityDeviceFmt, 0, verityInfo.RootDigest)
346 expectedPMemDevice := "/dev/pmem0"
347 mapperLinearPath := fmt.Sprintf("/dev/mapper/%s", expectedLinearTarget)
348 mapperVerityPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityTarget)
349 dmLinearCalled := false
350 dmVerityCalled := false
351 mountCalled := false
352
353 createZeroSectorLinearTarget = func(_ context.Context, source, name string, mapping *guestresource.LCOWVPMemMappingInfo) (string, error) {
354 dmLinearCalled = true
355 if source != expectedPMemDevice {
356 t.Errorf("expected createZeroSectorLinearTarget source %s, got %s", expectedPMemDevice, source)
357 }
358 if name != expectedLinearTarget {
359 t.Errorf("expected createZeroSectorLinearTarget name %s, got %s", expectedLinearTarget, name)
360 }
361 return mapperLinearPath, nil
362 }
363 createVerityTarget = func(_ context.Context, source, name string, verity *guestresource.DeviceVerityInfo) (string, error) {
364 dmVerityCalled = true
365 if source != mapperLinearPath {
366 t.Errorf("expected createVerityTarget source %s, got %s", mapperLinearPath, source)
367 }
368 if name != expectedVerityTarget {
369 t.Errorf("expected createVerityTarget target name %s, got %s", expectedVerityTarget, name)
370 }
371 return mapperVerityPath, nil
372 }
373 mountInternal = func(_ context.Context, source, target string) error {
374 mountCalled = true
375 if source != mapperVerityPath {
376 t.Errorf("expected Mount source %s, got %s", mapperVerityPath, source)
377 }
378 return nil
379 }
380
381 if err := Mount(
382 context.Background(),
383 0,
384 "/foo",
385 mapping,
386 verityInfo,
387 ); err != nil {
388 t.Fatalf("unexpected error during Mount call: %s", err)
389 }
390 if !dmLinearCalled {
391 t.Fatal("expected createZeroSectorLinearTarget call")
392 }
393 if !dmVerityCalled {
394 t.Fatal("expected createVerityTarget call")
395 }
396 if !mountCalled {
397 t.Fatal("expected mountInternal call")
398 }
399 }
400
401 func Test_RemoveDevice_Called_For_LinearTarget_On_MountInternalFailure(t *testing.T) {
402 clearTestDependencies()
403
404 mappingInfo := &guestresource.LCOWVPMemMappingInfo{
405 DeviceOffsetInBytes: 0,
406 DeviceSizeInBytes: 1024,
407 }
408 expectedError := errors.New("mountInternal error")
409 expectedTarget := fmt.Sprintf(linearDeviceFmt, 0, mappingInfo.DeviceOffsetInBytes, mappingInfo.DeviceSizeInBytes)
410 mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedTarget)
411 removeDeviceCalled := false
412
413 createZeroSectorLinearTarget = func(_ context.Context, source, name string, mapping *guestresource.LCOWVPMemMappingInfo) (string, error) {
414 return mapperPath, nil
415 }
416 mountInternal = func(_ context.Context, source, target string) error {
417 return expectedError
418 }
419 removeDevice = func(name string) error {
420 removeDeviceCalled = true
421 if name != expectedTarget {
422 t.Errorf("expected removeDevice linear target %s, got %s", expectedTarget, name)
423 }
424 return nil
425 }
426
427 if err := Mount(
428 context.Background(),
429 0,
430 "/foo",
431 mappingInfo,
432 nil,
433 ); err != expectedError {
434 t.Fatalf("expected Mount error %s, got %s", expectedError, err)
435 }
436 if !removeDeviceCalled {
437 t.Fatal("expected removeDevice to be callled")
438 }
439 }
440
441 func Test_RemoveDevice_Called_For_VerityTarget_On_MountInternalFailure(t *testing.T) {
442 clearTestDependencies()
443
444 verity := &guestresource.DeviceVerityInfo{
445 RootDigest: "hash",
446 }
447 expectedVerityTarget := fmt.Sprintf(verityDeviceFmt, 0, verity.RootDigest)
448 expectedError := errors.New("mountInternal error")
449 mapperPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityTarget)
450 removeDeviceCalled := false
451
452 createVerityTarget = func(_ context.Context, source, name string, verity *guestresource.DeviceVerityInfo) (string, error) {
453 return mapperPath, nil
454 }
455 mountInternal = func(_ context.Context, _, _ string) error {
456 return expectedError
457 }
458 removeDevice = func(name string) error {
459 removeDeviceCalled = true
460 if name != expectedVerityTarget {
461 t.Errorf("expected removeDevice verity target %s, got %s", expectedVerityTarget, name)
462 }
463 return nil
464 }
465
466 if err := Mount(
467 context.Background(),
468 0,
469 "/foo",
470 nil,
471 verity,
472 ); err != expectedError {
473 t.Fatalf("expected Mount error %s, got %s", expectedError, err)
474 }
475 if !removeDeviceCalled {
476 t.Fatal("expected removeDevice to be called")
477 }
478 }
479
480 func Test_RemoveDevice_Called_For_Both_Targets_On_MountInternalFailure(t *testing.T) {
481 clearTestDependencies()
482
483 mapping := &guestresource.LCOWVPMemMappingInfo{
484 DeviceOffsetInBytes: 0,
485 DeviceSizeInBytes: 1024,
486 }
487 verity := &guestresource.DeviceVerityInfo{
488 RootDigest: "hash",
489 }
490 expectedError := errors.New("mountInternal error")
491 expectedLinearTarget := fmt.Sprintf(linearDeviceFmt, 0, mapping.DeviceOffsetInBytes, mapping.DeviceSizeInBytes)
492 expectedVerityTarget := fmt.Sprintf(verityDeviceFmt, 0, verity.RootDigest)
493 expectedPMemDevice := "/dev/pmem0"
494 mapperLinearPath := fmt.Sprintf("/dev/mapper/%s", expectedLinearTarget)
495 mapperVerityPath := fmt.Sprintf("/dev/mapper/%s", expectedVerityTarget)
496 rmLinearCalled := false
497 rmVerityCalled := false
498
499 createZeroSectorLinearTarget = func(_ context.Context, source, name string, m *guestresource.LCOWVPMemMappingInfo) (string, error) {
500 if source != expectedPMemDevice {
501 t.Errorf("expected createZeroSectorLinearTarget source %s, got %s", expectedPMemDevice, source)
502 }
503 return mapperLinearPath, nil
504 }
505 createVerityTarget = func(_ context.Context, source, name string, v *guestresource.DeviceVerityInfo) (string, error) {
506 if source != mapperLinearPath {
507 t.Errorf("expected createVerityTarget to be called with %s, got %s", mapperLinearPath, source)
508 }
509 if name != expectedVerityTarget {
510 t.Errorf("expected createVerityTarget target %s, got %s", expectedVerityTarget, name)
511 }
512 return mapperVerityPath, nil
513 }
514 removeDevice = func(name string) error {
515 if name != expectedLinearTarget && name != expectedVerityTarget {
516 t.Errorf("unexpected removeDevice target name %s", name)
517 }
518 if name == expectedLinearTarget {
519 rmLinearCalled = true
520 }
521 if name == expectedVerityTarget {
522 rmVerityCalled = true
523 }
524 return nil
525 }
526 mountInternal = func(_ context.Context, _, _ string) error {
527 return expectedError
528 }
529
530 if err := Mount(
531 context.Background(),
532 0,
533 "/foo",
534 mapping,
535 verity,
536 ); err != expectedError {
537 t.Fatalf("expected Mount error %s, got %s", expectedError, err)
538 }
539 if !rmLinearCalled {
540 t.Fatal("expected removeDevice for linear target to be called")
541 }
542 if !rmVerityCalled {
543 t.Fatal("expected removeDevice for verity target to be called")
544 }
545 }
546
View as plain text