1
2
3
4
5
6
7 package socket_test
8
9 import (
10 "bytes"
11 "fmt"
12 "net"
13 "os"
14 "os/exec"
15 "path/filepath"
16 "runtime"
17 "strings"
18 "syscall"
19 "testing"
20
21 "golang.org/x/net/internal/socket"
22 "golang.org/x/net/nettest"
23 )
24
25 func TestSocket(t *testing.T) {
26 t.Run("Option", func(t *testing.T) {
27 testSocketOption(t, &socket.Option{Level: syscall.SOL_SOCKET, Name: syscall.SO_RCVBUF, Len: 4})
28 })
29 }
30
31 func testSocketOption(t *testing.T, so *socket.Option) {
32 c, err := nettest.NewLocalPacketListener("udp")
33 if err != nil {
34 t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
35 }
36 defer c.Close()
37 cc, err := socket.NewConn(c.(net.Conn))
38 if err != nil {
39 t.Fatal(err)
40 }
41 const N = 2048
42 if err := so.SetInt(cc, N); err != nil {
43 t.Fatal(err)
44 }
45 n, err := so.GetInt(cc)
46 if err != nil {
47 t.Fatal(err)
48 }
49 if n < N {
50 t.Fatalf("got %d; want greater than or equal to %d", n, N)
51 }
52 }
53
// mockControl describes one control message used as test input. Level and
// Type are the values marshaled into the message header; Data is the
// payload (nil for a message with no payload).
type mockControl struct {
	Level, Type int
	Data        []byte
}
59
60 func TestControlMessage(t *testing.T) {
61 switch runtime.GOOS {
62 case "windows":
63 t.Skipf("not supported on %s", runtime.GOOS)
64 }
65
66 for _, tt := range []struct {
67 cs []mockControl
68 }{
69 {
70 []mockControl{
71 {Level: 1, Type: 1},
72 },
73 },
74 {
75 []mockControl{
76 {Level: 2, Type: 2, Data: []byte{0xfe}},
77 },
78 },
79 {
80 []mockControl{
81 {Level: 3, Type: 3, Data: []byte{0xfe, 0xff, 0xff, 0xfe}},
82 },
83 },
84 {
85 []mockControl{
86 {Level: 4, Type: 4, Data: []byte{0xfe, 0xff, 0xff, 0xfe, 0xfe, 0xff, 0xff, 0xfe}},
87 },
88 },
89 {
90 []mockControl{
91 {Level: 4, Type: 4, Data: []byte{0xfe, 0xff, 0xff, 0xfe, 0xfe, 0xff, 0xff, 0xfe}},
92 {Level: 2, Type: 2, Data: []byte{0xfe}},
93 },
94 },
95 } {
96 var w []byte
97 var tailPadLen int
98 mm := socket.NewControlMessage([]int{0})
99 for i, c := range tt.cs {
100 m := socket.NewControlMessage([]int{len(c.Data)})
101 l := len(m) - len(mm)
102 if i == len(tt.cs)-1 && l > len(c.Data) {
103 tailPadLen = l - len(c.Data)
104 }
105 w = append(w, m...)
106 }
107
108 var err error
109 ww := make([]byte, len(w))
110 copy(ww, w)
111 m := socket.ControlMessage(ww)
112 for _, c := range tt.cs {
113 if err = m.MarshalHeader(c.Level, c.Type, len(c.Data)); err != nil {
114 t.Fatalf("(%v).MarshalHeader() = %v", tt.cs, err)
115 }
116 copy(m.Data(len(c.Data)), c.Data)
117 m = m.Next(len(c.Data))
118 }
119 m = socket.ControlMessage(w)
120 for _, c := range tt.cs {
121 m, err = m.Marshal(c.Level, c.Type, c.Data)
122 if err != nil {
123 t.Fatalf("(%v).Marshal() = %v", tt.cs, err)
124 }
125 }
126 if !bytes.Equal(ww, w) {
127 t.Fatalf("got %#v; want %#v", ww, w)
128 }
129
130 ws := [][]byte{w}
131 if tailPadLen > 0 {
132
133 nopad := w[:len(w)-tailPadLen]
134 ws = append(ws, [][]byte{nopad}...)
135 }
136 for _, w := range ws {
137 ms, err := socket.ControlMessage(w).Parse()
138 if err != nil {
139 t.Fatalf("(%v).Parse() = %v", tt.cs, err)
140 }
141 for i, m := range ms {
142 lvl, typ, dataLen, err := m.ParseHeader()
143 if err != nil {
144 t.Fatalf("(%v).ParseHeader() = %v", tt.cs, err)
145 }
146 if lvl != tt.cs[i].Level || typ != tt.cs[i].Type || dataLen != len(tt.cs[i].Data) {
147 t.Fatalf("%v: got %d, %d, %d; want %d, %d, %d", tt.cs[i], lvl, typ, dataLen, tt.cs[i].Level, tt.cs[i].Type, len(tt.cs[i].Data))
148 }
149 }
150 }
151 }
152 }
153
154 func TestUDP(t *testing.T) {
155 switch runtime.GOOS {
156 case "windows":
157 t.Skipf("not supported on %s", runtime.GOOS)
158 }
159
160 c, err := nettest.NewLocalPacketListener("udp")
161 if err != nil {
162 t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
163 }
164 defer c.Close()
165
166 type wrappedConn struct{ *net.UDPConn }
167 cc, err := socket.NewConn(&wrappedConn{c.(*net.UDPConn)})
168 if err != nil {
169 t.Fatal(err)
170 }
171
172
173 cDialed, err := net.Dial("udp", c.LocalAddr().String())
174 if err != nil {
175 t.Fatal(err)
176 }
177 ccDialed, err := socket.NewConn(cDialed)
178 if err != nil {
179 t.Fatal(err)
180 }
181
182 const data = "HELLO-R-U-THERE"
183 messageTests := []struct {
184 name string
185 conn *socket.Conn
186 dest net.Addr
187 }{
188 {
189 name: "Message",
190 conn: cc,
191 dest: c.LocalAddr(),
192 },
193 {
194 name: "Message-dialed",
195 conn: ccDialed,
196 dest: nil,
197 },
198 }
199 for _, tt := range messageTests {
200 t.Run(tt.name, func(t *testing.T) {
201 wm := socket.Message{
202 Buffers: bytes.SplitAfter([]byte(data), []byte("-")),
203 Addr: tt.dest,
204 }
205 if err := tt.conn.SendMsg(&wm, 0); err != nil {
206 t.Fatal(err)
207 }
208 b := make([]byte, 32)
209 rm := socket.Message{
210 Buffers: [][]byte{b[:1], b[1:3], b[3:7], b[7:11], b[11:]},
211 }
212 if err := cc.RecvMsg(&rm, 0); err != nil {
213 t.Fatal(err)
214 }
215 received := string(b[:rm.N])
216 if received != data {
217 t.Fatalf("Roundtrip SendMsg/RecvMsg got %q; want %q", received, data)
218 }
219 })
220 }
221
222 switch runtime.GOOS {
223 case "android", "linux":
224 messagesTests := []struct {
225 name string
226 conn *socket.Conn
227 dest net.Addr
228 }{
229 {
230 name: "Messages",
231 conn: cc,
232 dest: c.LocalAddr(),
233 },
234 {
235 name: "Messages-dialed",
236 conn: ccDialed,
237 dest: nil,
238 },
239 }
240 for _, tt := range messagesTests {
241 t.Run(tt.name, func(t *testing.T) {
242 wmbs := bytes.SplitAfter([]byte(data), []byte("-"))
243 wms := []socket.Message{
244 {Buffers: wmbs[:1], Addr: tt.dest},
245 {Buffers: wmbs[1:], Addr: tt.dest},
246 }
247 n, err := tt.conn.SendMsgs(wms, 0)
248 if err != nil {
249 t.Fatal(err)
250 }
251 if n != len(wms) {
252 t.Fatalf("SendMsgs(%#v) != %d; want %d", wms, n, len(wms))
253 }
254 rmbs := [][]byte{make([]byte, 32), make([]byte, 32)}
255 rms := []socket.Message{
256 {Buffers: [][]byte{rmbs[0]}},
257 {Buffers: [][]byte{rmbs[1][:1], rmbs[1][1:3], rmbs[1][3:7], rmbs[1][7:11], rmbs[1][11:]}},
258 }
259 nrecv := 0
260 for nrecv < len(rms) {
261 n, err := cc.RecvMsgs(rms[nrecv:], 0)
262 if err != nil {
263 t.Fatal(err)
264 }
265 nrecv += n
266 }
267 received0, received1 := string(rmbs[0][:rms[0].N]), string(rmbs[1][:rms[1].N])
268 assembled := received0 + received1
269 assembledReordered := received1 + received0
270 if assembled != data && assembledReordered != data {
271 t.Fatalf("Roundtrip SendMsgs/RecvMsgs got %q / %q; want %q", assembled, assembledReordered, data)
272 }
273 })
274 }
275 t.Run("Messages-undialed-no-dst", func(t *testing.T) {
276
277
278 data := []byte("HELLO-R-U-THERE")
279 wmbs := bytes.SplitAfter(data, []byte("-"))
280 wms := []socket.Message{
281 {Buffers: wmbs[:1], Addr: nil},
282 {Buffers: wmbs[1:], Addr: nil},
283 }
284 n, err := cc.SendMsgs(wms, 0)
285 if n != 0 && err == nil {
286 t.Fatal("expected error, destination address required")
287 }
288 })
289 }
290
291
292
293
294
295
296 wm := socket.Message{
297 Buffers: [][]byte{{}},
298 Addr: c.LocalAddr(),
299 }
300 cc.SendMsg(&wm, 0)
301 wms := []socket.Message{
302 {Buffers: [][]byte{{}}, Addr: c.LocalAddr()},
303 }
304 cc.SendMsgs(wms, 0)
305 }
306
307 func BenchmarkUDP(b *testing.B) {
308 c, err := nettest.NewLocalPacketListener("udp")
309 if err != nil {
310 b.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
311 }
312 defer c.Close()
313 cc, err := socket.NewConn(c.(net.Conn))
314 if err != nil {
315 b.Fatal(err)
316 }
317 data := []byte("HELLO-R-U-THERE")
318 wm := socket.Message{
319 Buffers: [][]byte{data},
320 Addr: c.LocalAddr(),
321 }
322 rm := socket.Message{
323 Buffers: [][]byte{make([]byte, 128)},
324 OOB: make([]byte, 128),
325 }
326
327 for M := 1; M <= 1<<9; M = M << 1 {
328 b.Run(fmt.Sprintf("Iter-%d", M), func(b *testing.B) {
329 for i := 0; i < b.N; i++ {
330 for j := 0; j < M; j++ {
331 if err := cc.SendMsg(&wm, 0); err != nil {
332 b.Fatal(err)
333 }
334 if err := cc.RecvMsg(&rm, 0); err != nil {
335 b.Fatal(err)
336 }
337 }
338 }
339 })
340 switch runtime.GOOS {
341 case "android", "linux":
342 wms := make([]socket.Message, M)
343 for i := range wms {
344 wms[i].Buffers = [][]byte{data}
345 wms[i].Addr = c.LocalAddr()
346 }
347 rms := make([]socket.Message, M)
348 for i := range rms {
349 rms[i].Buffers = [][]byte{make([]byte, 128)}
350 rms[i].OOB = make([]byte, 128)
351 }
352 b.Run(fmt.Sprintf("Batch-%d", M), func(b *testing.B) {
353 for i := 0; i < b.N; i++ {
354 if _, err := cc.SendMsgs(wms, 0); err != nil {
355 b.Fatal(err)
356 }
357 if _, err := cc.RecvMsgs(rms, 0); err != nil {
358 b.Fatal(err)
359 }
360 }
361 })
362 }
363 }
364 }
365
366 func TestRace(t *testing.T) {
367 tests := []string{
368 `
369 package main
370 import (
371 "log"
372 "net"
373
374 "golang.org/x/net/ipv4"
375 )
376
377 var g byte
378
379 func main() {
380 c, err := net.ListenPacket("udp", "127.0.0.1:0")
381 if err != nil {
382 log.Fatalf("ListenPacket: %v", err)
383 }
384 cc := ipv4.NewPacketConn(c)
385 sync := make(chan bool)
386 src := make([]byte, 100)
387 dst := make([]byte, 100)
388 go func() {
389 if _, err := cc.WriteTo(src, nil, c.LocalAddr()); err != nil {
390 log.Fatalf("WriteTo: %v", err)
391 }
392 }()
393 go func() {
394 if _, _, _, err := cc.ReadFrom(dst); err != nil {
395 log.Fatalf("ReadFrom: %v", err)
396 }
397 sync <- true
398 }()
399 g = dst[0]
400 <-sync
401 }
402 `,
403 `
404 package main
405 import (
406 "log"
407 "net"
408
409 "golang.org/x/net/ipv4"
410 )
411
412 func main() {
413 c, err := net.ListenPacket("udp", "127.0.0.1:0")
414 if err != nil {
415 log.Fatalf("ListenPacket: %v", err)
416 }
417 cc := ipv4.NewPacketConn(c)
418 sync := make(chan bool)
419 src := make([]byte, 100)
420 dst := make([]byte, 100)
421 go func() {
422 if _, err := cc.WriteTo(src, nil, c.LocalAddr()); err != nil {
423 log.Fatalf("WriteTo: %v", err)
424 }
425 sync <- true
426 }()
427 src[0] = 0
428 go func() {
429 if _, _, _, err := cc.ReadFrom(dst); err != nil {
430 log.Fatalf("ReadFrom: %v", err)
431 }
432 }()
433 <-sync
434 }
435 `,
436 }
437 platforms := map[string]bool{
438 "linux/amd64": true,
439 "linux/ppc64le": true,
440 "linux/arm64": true,
441 }
442 if !platforms[runtime.GOOS+"/"+runtime.GOARCH] {
443 t.Skip("skipping test on non-race-enabled host.")
444 }
445 if runtime.Compiler == "gccgo" {
446 t.Skip("skipping race test when built with gccgo")
447 }
448 dir, err := os.MkdirTemp("", "testrace")
449 if err != nil {
450 t.Fatalf("failed to create temp directory: %v", err)
451 }
452 defer os.RemoveAll(dir)
453 goBinary := filepath.Join(runtime.GOROOT(), "bin", "go")
454 t.Logf("%s version", goBinary)
455 got, err := exec.Command(goBinary, "version").CombinedOutput()
456 if len(got) > 0 {
457 t.Logf("%s", got)
458 }
459 if err != nil {
460 t.Fatalf("go version failed: %v", err)
461 }
462 for i, test := range tests {
463 t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) {
464 src := filepath.Join(dir, fmt.Sprintf("test%d.go", i))
465 if err := os.WriteFile(src, []byte(test), 0644); err != nil {
466 t.Fatalf("failed to write file: %v", err)
467 }
468 t.Logf("%s run -race %s", goBinary, src)
469 got, err := exec.Command(goBinary, "run", "-race", src).CombinedOutput()
470 if len(got) > 0 {
471 t.Logf("%s", got)
472 }
473 if strings.Contains(string(got), "-race requires cgo") {
474 t.Log("CGO is not enabled so can't use -race")
475 } else if !strings.Contains(string(got), "WARNING: DATA RACE") {
476 t.Errorf("race not detected for test %d: err:%v", i, err)
477 }
478 })
479 }
480 }
481
View as plain text