1
2
3
4 package socket_test
5
6 import (
7 "context"
8 "errors"
9 "fmt"
10 "math"
11 "net"
12 "os"
13 "runtime"
14 "testing"
15
16 "github.com/google/go-cmp/cmp"
17 "github.com/mdlayher/socket"
18 "github.com/mdlayher/socket/internal/sockettest"
19 "golang.org/x/sync/errgroup"
20 "golang.org/x/sys/unix"
21 )
22
23 func TestLinuxConnBuffers(t *testing.T) {
24 t.Parallel()
25
26
27
28
29 c, err := socket.Socket(unix.AF_INET, unix.SOCK_STREAM, 0, "tcpv4", nil)
30 if err != nil {
31 t.Fatalf("failed to open socket: %v", err)
32 }
33 defer c.Close()
34
35 const (
36 set = 8192
37
38
39
40
41
42
43 want = set * 2
44 )
45
46 if err := c.SetReadBuffer(set); err != nil {
47 t.Fatalf("failed to set read buffer size: %v", err)
48 }
49
50 if err := c.SetWriteBuffer(set); err != nil {
51 t.Fatalf("failed to set write buffer size: %v", err)
52 }
53
54
55
56
57 rcv, err := c.ReadBuffer()
58 if err != nil {
59 t.Fatalf("failed to get read buffer size: %v", err)
60 }
61
62 snd, err := c.WriteBuffer()
63 if err != nil {
64 t.Fatalf("failed to get write buffer size: %v", err)
65 }
66
67 if diff := cmp.Diff(want, rcv); diff != "" {
68 t.Fatalf("unexpected read buffer size (-want +got):\n%s", diff)
69 }
70 if diff := cmp.Diff(want, snd); diff != "" {
71 t.Fatalf("unexpected write buffer size (-want +got):\n%s", diff)
72 }
73 }
74
75 func TestLinuxNetworkNamespaces(t *testing.T) {
76 t.Parallel()
77
78 l, err := sockettest.Listen(0, nil)
79 if err != nil {
80 t.Fatalf("failed to create listener: %v", err)
81 }
82 defer l.Close()
83
84 addrC := make(chan net.Addr, 1)
85
86 var eg errgroup.Group
87 eg.Go(func() error {
88
89
90
91 runtime.LockOSThread()
92
93 if err := unix.Unshare(unix.CLONE_NEWNET); err != nil {
94
95 return fmt.Errorf("failed to unshare network namespace: %w", err)
96 }
97
98 ns, err := socket.ThreadNetNS()
99 if err != nil {
100 return fmt.Errorf("failed to get listener thread's network namespace: %v", err)
101 }
102
103
104
105 l, err := sockettest.Listen(
106 l.Addr().(*net.TCPAddr).Port,
107 &socket.Config{NetNS: ns.FD()},
108 )
109 if err != nil {
110 return fmt.Errorf("failed to create listener in network namespace: %v", err)
111 }
112 defer l.Close()
113
114 addrC <- l.Addr()
115 return nil
116 })
117
118 if err := eg.Wait(); err != nil {
119 if errors.Is(err, os.ErrPermission) {
120 t.Skipf("skipping, permission denied: %v", err)
121 }
122
123 t.Fatalf("failed to run listener thread: %v", err)
124 }
125
126 select {
127 case addr := <-addrC:
128 if diff := cmp.Diff(l.Addr(), addr); diff != "" {
129 t.Fatalf("unexpected network address (-want +got):\n%s", diff)
130 }
131 default:
132 t.Fatal("listener thread did not return its local address")
133 }
134 }
135
136 func TestLinuxDialVsockNoListener(t *testing.T) {
137 t.Parallel()
138
139
140
141 c, err := socket.Socket(unix.AF_VSOCK, unix.SOCK_STREAM, 0, "vsock", nil)
142 if err != nil {
143 t.Fatalf("failed to open socket: %v", err)
144 }
145 defer c.Close()
146
147
148
149 _, err = c.Connect(context.Background(), &unix.SockaddrVM{
150 CID: unix.VMADDR_CID_LOCAL,
151 Port: math.MaxUint32,
152 })
153 if err == nil {
154
155 t.Skipf("skipping, expected error but vsock successfully connected to local service")
156 }
157
158 want := os.NewSyscallError("connect", unix.ECONNRESET)
159 if diff := cmp.Diff(want, err); diff != "" {
160 t.Fatalf("unexpected connect error (-want +got):\n%s", diff)
161 }
162 }
163
164 func TestLinuxOpenPIDFD(t *testing.T) {
165
166
167 fd, err := unix.PidfdOpen(1, unix.PIDFD_NONBLOCK)
168 if err != nil {
169 t.Fatalf("failed to open pidfd for init: %v", err)
170 }
171
172 c, err := socket.New(fd, "pidfd")
173 if err != nil {
174 t.Fatalf("failed to open Conn for pidfd: %v", err)
175 }
176 _ = c.Close()
177 }
178
View as plain text