...
1
2
3 package netlink
4
5 import (
6 "os"
7 "runtime"
8 "syscall"
9 "testing"
10
11 "github.com/vishvananda/netns"
12 )
13
14
15
16
17
18 func TestNetNsIdByFd(t *testing.T) {
19 skipUnlessRoot(t)
20
21 ns, err := netns.New()
22 CheckErrorFail(t, err)
23
24
25
26
27
28 wantID := os.Getpid() << 16
29
30 h, err := NewHandle()
31 CheckErrorFail(t, err)
32 err = h.SetNetNsIdByFd(int(ns), wantID)
33 CheckErrorFail(t, err)
34
35
36 haveID, _ := h.GetNetNsIdByFd(int(ns))
37 if haveID != wantID {
38 t.Errorf("GetNetNsIdByFd returned %d, want %d", haveID, wantID)
39 }
40
41 ns.Close()
42 }
43
44
45
46
47 func TestNetNsIdByPid(t *testing.T) {
48 skipUnlessRoot(t)
49 runtime.LockOSThread()
50 origNs, _ := netns.Get()
51
52
53 ns, err := netns.New()
54 CheckErrorFail(t, err)
55 err = netns.Set(ns)
56 CheckErrorFail(t, err)
57
58 defer func() {
59 err := netns.Set(origNs)
60 if err != nil {
61 panic("failed to restore network ns, bailing")
62 }
63 runtime.UnlockOSThread()
64 }()
65
66
67 wantID := syscall.Gettid() << 16
68
69 h, err := NewHandle()
70 CheckErrorFail(t, err)
71 err = h.SetNetNsIdByPid(syscall.Gettid(), wantID)
72 CheckErrorFail(t, err)
73
74
75 haveID, _ := h.GetNetNsIdByPid(syscall.Gettid())
76 if haveID != wantID {
77 t.Errorf("GetNetNsIdByPid returned %d, want %d", haveID, wantID)
78 }
79 }
80
View as plain text