1
2
3
4 package netlink
5
6 import (
7 "bytes"
8 "crypto/rand"
9 "encoding/hex"
10 "fmt"
11 "io/ioutil"
12 "log"
13 "os"
14 "os/exec"
15 "runtime"
16 "strings"
17 "testing"
18
19 "github.com/vishvananda/netlink/nl"
20 "github.com/vishvananda/netns"
21 "golang.org/x/sys/unix"
22 )
23
// tearDownNetlinkTest is the cleanup callback returned by the setUp* helpers
// in this file. Callers invoke it once the test is finished to undo whatever
// the corresponding setup helper arranged.
type tearDownNetlinkTest func()
25
// skipUnlessRoot skips the calling test or benchmark when the current
// process is not running with UID 0.
func skipUnlessRoot(t testing.TB) {
	t.Helper()

	if uid := os.Getuid(); uid == 0 {
		return
	}
	t.Skip("Test requires root privileges.")
}
33
// skipUnlessKModuleLoaded skips the calling test unless every name in
// moduleNames appears in /proc/modules. Each missing module is logged before
// the test is skipped. If /proc/modules cannot be read the test fails.
func skipUnlessKModuleLoaded(t *testing.T, moduleNames ...string) {
	t.Helper()

	content, err := ioutil.ReadFile("/proc/modules")
	if err != nil {
		t.Fatal("Failed to open /proc/modules", err)
	}

	// Index the first space-separated field of every line; that field is
	// what gets compared against the requested module names.
	loaded := make(map[string]struct{})
	for _, entry := range strings.Split(string(content), "\n") {
		loaded[strings.SplitN(entry, " ", 2)[0]] = struct{}{}
	}

	missing := 0
	for _, want := range moduleNames {
		if _, ok := loaded[want]; ok {
			continue
		}
		t.Logf("Test requires missing kmodule %q.", want)
		missing++
	}
	if missing > 0 {
		t.SkipNow()
	}
}
66
67 func setUpNetlinkTest(t testing.TB) tearDownNetlinkTest {
68 skipUnlessRoot(t)
69
70
71
72 runtime.LockOSThread()
73 var err error
74 ns, err := netns.New()
75 if err != nil {
76 t.Fatal("Failed to create newns", ns)
77 }
78
79 return func() {
80 ns.Close()
81 runtime.UnlockOSThread()
82 }
83 }
84
85
86 func setUpNamedNetlinkTest(t *testing.T) (string, tearDownNetlinkTest) {
87 skipUnlessRoot(t)
88
89 origNS, err := netns.Get()
90 if err != nil {
91 t.Fatal("Failed saving orig namespace")
92 }
93
94
95 rnd := make([]byte, 4)
96 if _, err := rand.Read(rnd); err != nil {
97 t.Fatal("failed creating random ns name")
98 }
99 name := "netlinktest-" + hex.EncodeToString(rnd)
100
101 ns, err := netns.NewNamed(name)
102 if err != nil {
103 t.Fatal("Failed to create new ns", err)
104 }
105
106 runtime.LockOSThread()
107 cleanup := func() {
108 ns.Close()
109 netns.DeleteNamed(name)
110 netns.Set(origNS)
111 runtime.UnlockOSThread()
112 }
113
114 if err := netns.Set(ns); err != nil {
115 cleanup()
116 t.Fatal("Failed entering new namespace", err)
117 }
118
119 return name, cleanup
120 }
121
122 func setUpNetlinkTestWithLoopback(t *testing.T) tearDownNetlinkTest {
123 skipUnlessRoot(t)
124
125 runtime.LockOSThread()
126 ns, err := netns.New()
127 if err != nil {
128 t.Fatal("Failed to create new netns", ns)
129 }
130
131 link, err := LinkByName("lo")
132 if err != nil {
133 t.Fatalf("Failed to find \"lo\" in new netns: %v", err)
134 }
135 if err := LinkSetUp(link); err != nil {
136 t.Fatalf("Failed to bring up \"lo\" in new netns: %v", err)
137 }
138
139 return func() {
140 ns.Close()
141 runtime.UnlockOSThread()
142 }
143 }
144
// setUpF writes value to the file at path, creating it or truncating any
// existing content. The test fails if the file cannot be opened, written,
// or closed, so a rejected write is not silently ignored.
func setUpF(t *testing.T, path, value string) {
	t.Helper()

	file, err := os.Create(path)
	if err != nil {
		t.Fatalf("Failed to open %s: %s", path, err)
	}
	if _, err := file.WriteString(value); err != nil {
		file.Close()
		t.Fatalf("Failed to write %q to %s: %s", value, path, err)
	}
	if err := file.Close(); err != nil {
		t.Fatalf("Failed to close %s: %s", path, err)
	}
}
153
154 func setUpMPLSNetlinkTest(t *testing.T) tearDownNetlinkTest {
155 if _, err := os.Stat("/proc/sys/net/mpls/platform_labels"); err != nil {
156 t.Skip("Test requires MPLS support.")
157 }
158 f := setUpNetlinkTest(t)
159 setUpF(t, "/proc/sys/net/mpls/platform_labels", "1024")
160 setUpF(t, "/proc/sys/net/mpls/conf/lo/input", "1")
161 return f
162 }
163
164 func setUpSEG6NetlinkTest(t *testing.T) tearDownNetlinkTest {
165
166 cmd := exec.Command("uname", "-r")
167 var out bytes.Buffer
168 cmd.Stdout = &out
169 if err := cmd.Run(); err != nil {
170 t.Fatal("Failed to run: uname -r")
171 }
172 s := []string{"/boot/config-", strings.TrimRight(out.String(), "\n")}
173 filename := strings.Join(s, "")
174
175 grepKey := func(key, fname string) (string, error) {
176 cmd := exec.Command("grep", key, filename)
177 var out bytes.Buffer
178 cmd.Stdout = &out
179 err := cmd.Run()
180 return strings.TrimRight(out.String(), "\n"), err
181 }
182 key := string("CONFIG_IPV6_SEG6_LWTUNNEL=y")
183 if _, err := grepKey(key, filename); err != nil {
184 msg := "Skipped test because it requires SEG6_LWTUNNEL support."
185 log.Println(msg)
186 t.Skip(msg)
187 }
188
189
190
191 return setUpNetlinkTest(t)
192 }
193
194 func setUpNetlinkTestWithKModule(t *testing.T, moduleNames ...string) tearDownNetlinkTest {
195 skipUnlessKModuleLoaded(t, moduleNames...)
196 return setUpNetlinkTest(t)
197 }
198 func setUpNamedNetlinkTestWithKModule(t *testing.T, moduleNames ...string) (string, tearDownNetlinkTest) {
199 file, err := ioutil.ReadFile("/proc/modules")
200 if err != nil {
201 t.Fatal("Failed to open /proc/modules", err)
202 }
203
204 foundRequiredMods := make(map[string]bool)
205 lines := strings.Split(string(file), "\n")
206
207 for _, name := range moduleNames {
208 foundRequiredMods[name] = false
209 for _, line := range lines {
210 n := strings.Split(line, " ")[0]
211 if n == name {
212 foundRequiredMods[name] = true
213 break
214 }
215 }
216 }
217
218 failed := false
219 for _, name := range moduleNames {
220 if found, _ := foundRequiredMods[name]; !found {
221 t.Logf("Test requires missing kmodule %q.", name)
222 failed = true
223 }
224 }
225 if failed {
226 t.SkipNow()
227 }
228
229 return setUpNamedNetlinkTest(t)
230 }
231
232 func remountSysfs() error {
233 if err := unix.Mount("", "/", "none", unix.MS_SLAVE|unix.MS_REC, ""); err != nil {
234 return err
235 }
236 if err := unix.Unmount("/sys", unix.MNT_DETACH); err != nil {
237 return err
238 }
239 return unix.Mount("", "/sys", "sysfs", 0, "")
240 }
241
242 func minKernelRequired(t *testing.T, kernel, major int) {
243 t.Helper()
244
245 k, m, err := KernelVersion()
246 if err != nil {
247 t.Fatal(err)
248 }
249 if k < kernel || k == kernel && m < major {
250 t.Skipf("Host Kernel (%d.%d) does not meet test's minimum required version: (%d.%d)",
251 k, m, kernel, major)
252 }
253 }
254
255 func KernelVersion() (kernel, major int, err error) {
256 uts := unix.Utsname{}
257 if err = unix.Uname(&uts); err != nil {
258 return
259 }
260
261 ba := make([]byte, 0, len(uts.Release))
262 for _, b := range uts.Release {
263 if b == 0 {
264 break
265 }
266 ba = append(ba, byte(b))
267 }
268 var rest string
269 if n, _ := fmt.Sscanf(string(ba), "%d.%d%s", &kernel, &major, &rest); n < 2 {
270 err = fmt.Errorf("can't parse kernel version in %q", string(ba))
271 }
272 return
273 }
274
275 func TestMain(m *testing.M) {
276 nl.EnableErrorMessageReporting = true
277 os.Exit(m.Run())
278 }
279
View as plain text