...

Source file src/github.com/cilium/ebpf/cmd/bpf2go/main_test.go

Documentation: github.com/cilium/ebpf/cmd/bpf2go

     1  package main
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"os"
     8  	"os/exec"
     9  	"path/filepath"
    10  	"runtime"
    11  	"sort"
    12  	"strings"
    13  	"testing"
    14  
    15  	qt "github.com/frankban/quicktest"
    16  	"github.com/google/go-cmp/cmp"
    17  )
    18  
    19  func TestRun(t *testing.T) {
    20  	if testing.Short() {
    21  		t.Skip("Not compiling with -short")
    22  	}
    23  
    24  	dir := mustWriteTempFile(t, "test.c", minimalSocketFilter)
    25  
    26  	cwd, err := os.Getwd()
    27  	if err != nil {
    28  		t.Fatal(err)
    29  	}
    30  
    31  	modRoot := filepath.Clean(filepath.Join(cwd, "../.."))
    32  	if _, err := os.Stat(filepath.Join(modRoot, "go.mod")); os.IsNotExist(err) {
    33  		t.Fatal("No go.mod file in", modRoot)
    34  	}
    35  
    36  	tmpDir, err := os.MkdirTemp("", "bpf2go-module-*")
    37  	if err != nil {
    38  		t.Fatal(err)
    39  	}
    40  	defer os.RemoveAll(tmpDir)
    41  
    42  	execInModule := func(name string, args ...string) {
    43  		t.Helper()
    44  
    45  		cmd := exec.Command(name, args...)
    46  		cmd.Dir = tmpDir
    47  		if out, err := cmd.CombinedOutput(); err != nil {
    48  			if out := string(out); out != "" {
    49  				t.Log(out)
    50  			}
    51  			t.Fatalf("Can't execute %s: %v", name, args)
    52  		}
    53  	}
    54  
    55  	execInModule("go", "mod", "init", "bpf2go-test")
    56  
    57  	execInModule("go", "mod", "edit",
    58  		// Require the module. The version doesn't matter due to the replace
    59  		// below.
    60  		fmt.Sprintf("-require=%s@v0.0.0", ebpfModule),
    61  		// Replace the module with the current version.
    62  		fmt.Sprintf("-replace=%s=%s", ebpfModule, modRoot),
    63  	)
    64  
    65  	err = run(io.Discard, "foo", tmpDir, []string{
    66  		"-cc", clangBin,
    67  		"bar",
    68  		filepath.Join(dir, "test.c"),
    69  	})
    70  
    71  	if err != nil {
    72  		t.Fatal("Can't run:", err)
    73  	}
    74  
    75  	for _, arch := range []string{
    76  		"amd64", // little-endian
    77  		"s390x", // big-endian
    78  	} {
    79  		t.Run(arch, func(t *testing.T) {
    80  			goBin := exec.Command("go", "build", "-mod=mod")
    81  			goBin.Dir = tmpDir
    82  			goBin.Env = append(os.Environ(),
    83  				"GOOS=linux",
    84  				"GOARCH="+arch,
    85  			)
    86  			out, err := goBin.CombinedOutput()
    87  			if err != nil {
    88  				if out := string(out); out != "" {
    89  					t.Log(out)
    90  				}
    91  				t.Error("Can't compile package:", err)
    92  			}
    93  		})
    94  	}
    95  }
    96  
    97  func TestHelp(t *testing.T) {
    98  	var stdout bytes.Buffer
    99  	err := run(&stdout, "", "", []string{"-help"})
   100  	if err != nil {
   101  		t.Fatal("Can't execute -help")
   102  	}
   103  
   104  	if stdout.Len() == 0 {
   105  		t.Error("-help doesn't write to stdout")
   106  	}
   107  }
   108  
   109  func TestDisableStripping(t *testing.T) {
   110  	dir := mustWriteTempFile(t, "test.c", minimalSocketFilter)
   111  
   112  	err := run(io.Discard, "foo", dir, []string{
   113  		"-cc", "clang-9",
   114  		"-strip", "binary-that-certainly-doesnt-exist",
   115  		"-no-strip",
   116  		"bar",
   117  		filepath.Join(dir, "test.c"),
   118  	})
   119  
   120  	if err != nil {
   121  		t.Fatal("Can't run with stripping disabled:", err)
   122  	}
   123  }
   124  
   125  func TestCollectTargets(t *testing.T) {
   126  	clangArches := make(map[string][]string)
   127  	linuxArchesLE := make(map[string][]string)
   128  	linuxArchesBE := make(map[string][]string)
   129  	for arch, archTarget := range targetByGoArch {
   130  		clangArches[archTarget.clang] = append(clangArches[archTarget.clang], arch)
   131  		if archTarget.clang == "bpfel" {
   132  			linuxArchesLE[archTarget.linux] = append(linuxArchesLE[archTarget.linux], arch)
   133  			continue
   134  		}
   135  		linuxArchesBE[archTarget.linux] = append(linuxArchesBE[archTarget.linux], arch)
   136  	}
   137  	for i := range clangArches {
   138  		sort.Strings(clangArches[i])
   139  	}
   140  	for i := range linuxArchesLE {
   141  		sort.Strings(linuxArchesLE[i])
   142  	}
   143  	for i := range linuxArchesBE {
   144  		sort.Strings(linuxArchesBE[i])
   145  	}
   146  
   147  	nativeTarget := make(map[target][]string)
   148  	for arch, archTarget := range targetByGoArch {
   149  		if arch == runtime.GOARCH {
   150  			if archTarget.clang == "bpfel" {
   151  				nativeTarget[archTarget] = linuxArchesLE[archTarget.linux]
   152  			} else {
   153  				nativeTarget[archTarget] = linuxArchesBE[archTarget.linux]
   154  			}
   155  			break
   156  		}
   157  	}
   158  
   159  	tests := []struct {
   160  		targets []string
   161  		want    map[target][]string
   162  	}{
   163  		{
   164  			[]string{"bpf", "bpfel", "bpfeb"},
   165  			map[target][]string{
   166  				{"bpf", ""}:   nil,
   167  				{"bpfel", ""}: clangArches["bpfel"],
   168  				{"bpfeb", ""}: clangArches["bpfeb"],
   169  			},
   170  		},
   171  		{
   172  			[]string{"amd64", "386"},
   173  			map[target][]string{
   174  				{"bpfel", "x86"}: linuxArchesLE["x86"],
   175  			},
   176  		},
   177  		{
   178  			[]string{"amd64", "arm64be"},
   179  			map[target][]string{
   180  				{"bpfeb", "arm64"}: linuxArchesBE["arm64"],
   181  				{"bpfel", "x86"}:   linuxArchesLE["x86"],
   182  			},
   183  		},
   184  		{
   185  			[]string{"native"},
   186  			nativeTarget,
   187  		},
   188  	}
   189  
   190  	for _, test := range tests {
   191  		name := strings.Join(test.targets, ",")
   192  		t.Run(name, func(t *testing.T) {
   193  			have, err := collectTargets(test.targets)
   194  			if err != nil {
   195  				t.Fatal(err)
   196  			}
   197  
   198  			if diff := cmp.Diff(test.want, have); diff != "" {
   199  				t.Errorf("Result mismatch (-want +got):\n%s", diff)
   200  			}
   201  		})
   202  	}
   203  }
   204  
   205  func TestCollectTargetsErrors(t *testing.T) {
   206  	tests := []struct {
   207  		name   string
   208  		target string
   209  	}{
   210  		{"unknown", "frood"},
   211  		{"no linux target", "mips64p32le"},
   212  	}
   213  
   214  	for _, test := range tests {
   215  		t.Run(test.name, func(t *testing.T) {
   216  			_, err := collectTargets([]string{test.target})
   217  			if err == nil {
   218  				t.Fatal("Function did not return an error")
   219  			}
   220  			t.Log("Error message:", err)
   221  		})
   222  	}
   223  }
   224  
   225  func TestConvertGOARCH(t *testing.T) {
   226  	tmp := mustWriteTempFile(t, "test.c",
   227  		`
   228  #ifndef __TARGET_ARCH_x86
   229  #error __TARGET_ARCH_x86 is not defined
   230  #endif`,
   231  	)
   232  
   233  	b2g := bpf2go{
   234  		pkg:              "test",
   235  		stdout:           io.Discard,
   236  		ident:            "test",
   237  		cc:               clangBin,
   238  		disableStripping: true,
   239  		sourceFile:       tmp + "/test.c",
   240  		outputDir:        tmp,
   241  	}
   242  
   243  	if err := b2g.convert(targetByGoArch["amd64"], nil); err != nil {
   244  		t.Fatal("Can't target GOARCH:", err)
   245  	}
   246  }
   247  
   248  func TestCTypes(t *testing.T) {
   249  	var ct cTypes
   250  	valid := []string{
   251  		"abcdefghijklmnopqrstuvqxyABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890_",
   252  		"y",
   253  	}
   254  	for _, value := range valid {
   255  		if err := ct.Set(value); err != nil {
   256  			t.Fatalf("Set returned an error for %q: %s", value, err)
   257  		}
   258  	}
   259  	qt.Assert(t, ct, qt.ContentEquals, cTypes(valid))
   260  
   261  	for _, value := range []string{
   262  		"",
   263  		" ",
   264  		" frood",
   265  		"foo\nbar",
   266  		".",
   267  		",",
   268  		"+",
   269  		"-",
   270  	} {
   271  		ct = nil
   272  		if err := ct.Set(value); err == nil {
   273  			t.Fatalf("Set did not return an error for %q", value)
   274  		}
   275  	}
   276  
   277  	ct = nil
   278  	qt.Assert(t, ct.Set("foo"), qt.IsNil)
   279  	qt.Assert(t, ct.Set("foo"), qt.IsNotNil)
   280  }
   281  

View as plain text