...

Source file src/golang.org/x/net/internal/socket/socket_test.go

Documentation: golang.org/x/net/internal/socket

     1  // Copyright 2017 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || windows || zos
     6  
     7  package socket_test
     8  
     9  import (
    10  	"bytes"
    11  	"fmt"
    12  	"net"
    13  	"os"
    14  	"os/exec"
    15  	"path/filepath"
    16  	"runtime"
    17  	"strings"
    18  	"syscall"
    19  	"testing"
    20  
    21  	"golang.org/x/net/internal/socket"
    22  	"golang.org/x/net/nettest"
    23  )
    24  
    25  func TestSocket(t *testing.T) {
    26  	t.Run("Option", func(t *testing.T) {
    27  		testSocketOption(t, &socket.Option{Level: syscall.SOL_SOCKET, Name: syscall.SO_RCVBUF, Len: 4})
    28  	})
    29  }
    30  
    31  func testSocketOption(t *testing.T, so *socket.Option) {
    32  	c, err := nettest.NewLocalPacketListener("udp")
    33  	if err != nil {
    34  		t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
    35  	}
    36  	defer c.Close()
    37  	cc, err := socket.NewConn(c.(net.Conn))
    38  	if err != nil {
    39  		t.Fatal(err)
    40  	}
    41  	const N = 2048
    42  	if err := so.SetInt(cc, N); err != nil {
    43  		t.Fatal(err)
    44  	}
    45  	n, err := so.GetInt(cc)
    46  	if err != nil {
    47  		t.Fatal(err)
    48  	}
    49  	if n < N {
    50  		t.Fatalf("got %d; want greater than or equal to %d", n, N)
    51  	}
    52  }
    53  
    54  type mockControl struct {
    55  	Level int
    56  	Type  int
    57  	Data  []byte
    58  }
    59  
    60  func TestControlMessage(t *testing.T) {
    61  	switch runtime.GOOS {
    62  	case "windows":
    63  		t.Skipf("not supported on %s", runtime.GOOS)
    64  	}
    65  
    66  	for _, tt := range []struct {
    67  		cs []mockControl
    68  	}{
    69  		{
    70  			[]mockControl{
    71  				{Level: 1, Type: 1},
    72  			},
    73  		},
    74  		{
    75  			[]mockControl{
    76  				{Level: 2, Type: 2, Data: []byte{0xfe}},
    77  			},
    78  		},
    79  		{
    80  			[]mockControl{
    81  				{Level: 3, Type: 3, Data: []byte{0xfe, 0xff, 0xff, 0xfe}},
    82  			},
    83  		},
    84  		{
    85  			[]mockControl{
    86  				{Level: 4, Type: 4, Data: []byte{0xfe, 0xff, 0xff, 0xfe, 0xfe, 0xff, 0xff, 0xfe}},
    87  			},
    88  		},
    89  		{
    90  			[]mockControl{
    91  				{Level: 4, Type: 4, Data: []byte{0xfe, 0xff, 0xff, 0xfe, 0xfe, 0xff, 0xff, 0xfe}},
    92  				{Level: 2, Type: 2, Data: []byte{0xfe}},
    93  			},
    94  		},
    95  	} {
    96  		var w []byte
    97  		var tailPadLen int
    98  		mm := socket.NewControlMessage([]int{0})
    99  		for i, c := range tt.cs {
   100  			m := socket.NewControlMessage([]int{len(c.Data)})
   101  			l := len(m) - len(mm)
   102  			if i == len(tt.cs)-1 && l > len(c.Data) {
   103  				tailPadLen = l - len(c.Data)
   104  			}
   105  			w = append(w, m...)
   106  		}
   107  
   108  		var err error
   109  		ww := make([]byte, len(w))
   110  		copy(ww, w)
   111  		m := socket.ControlMessage(ww)
   112  		for _, c := range tt.cs {
   113  			if err = m.MarshalHeader(c.Level, c.Type, len(c.Data)); err != nil {
   114  				t.Fatalf("(%v).MarshalHeader() = %v", tt.cs, err)
   115  			}
   116  			copy(m.Data(len(c.Data)), c.Data)
   117  			m = m.Next(len(c.Data))
   118  		}
   119  		m = socket.ControlMessage(w)
   120  		for _, c := range tt.cs {
   121  			m, err = m.Marshal(c.Level, c.Type, c.Data)
   122  			if err != nil {
   123  				t.Fatalf("(%v).Marshal() = %v", tt.cs, err)
   124  			}
   125  		}
   126  		if !bytes.Equal(ww, w) {
   127  			t.Fatalf("got %#v; want %#v", ww, w)
   128  		}
   129  
   130  		ws := [][]byte{w}
   131  		if tailPadLen > 0 {
   132  			// Test a message with no tail padding.
   133  			nopad := w[:len(w)-tailPadLen]
   134  			ws = append(ws, [][]byte{nopad}...)
   135  		}
   136  		for _, w := range ws {
   137  			ms, err := socket.ControlMessage(w).Parse()
   138  			if err != nil {
   139  				t.Fatalf("(%v).Parse() = %v", tt.cs, err)
   140  			}
   141  			for i, m := range ms {
   142  				lvl, typ, dataLen, err := m.ParseHeader()
   143  				if err != nil {
   144  					t.Fatalf("(%v).ParseHeader() = %v", tt.cs, err)
   145  				}
   146  				if lvl != tt.cs[i].Level || typ != tt.cs[i].Type || dataLen != len(tt.cs[i].Data) {
   147  					t.Fatalf("%v: got %d, %d, %d; want %d, %d, %d", tt.cs[i], lvl, typ, dataLen, tt.cs[i].Level, tt.cs[i].Type, len(tt.cs[i].Data))
   148  				}
   149  			}
   150  		}
   151  	}
   152  }
   153  
   154  func TestUDP(t *testing.T) {
   155  	switch runtime.GOOS {
   156  	case "windows":
   157  		t.Skipf("not supported on %s", runtime.GOOS)
   158  	}
   159  
   160  	c, err := nettest.NewLocalPacketListener("udp")
   161  	if err != nil {
   162  		t.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
   163  	}
   164  	defer c.Close()
   165  	// test that wrapped connections work with NewConn too
   166  	type wrappedConn struct{ *net.UDPConn }
   167  	cc, err := socket.NewConn(&wrappedConn{c.(*net.UDPConn)})
   168  	if err != nil {
   169  		t.Fatal(err)
   170  	}
   171  
   172  	// create a dialed connection talking (only) to c/cc
   173  	cDialed, err := net.Dial("udp", c.LocalAddr().String())
   174  	if err != nil {
   175  		t.Fatal(err)
   176  	}
   177  	ccDialed, err := socket.NewConn(cDialed)
   178  	if err != nil {
   179  		t.Fatal(err)
   180  	}
   181  
   182  	const data = "HELLO-R-U-THERE"
   183  	messageTests := []struct {
   184  		name string
   185  		conn *socket.Conn
   186  		dest net.Addr
   187  	}{
   188  		{
   189  			name: "Message",
   190  			conn: cc,
   191  			dest: c.LocalAddr(),
   192  		},
   193  		{
   194  			name: "Message-dialed",
   195  			conn: ccDialed,
   196  			dest: nil,
   197  		},
   198  	}
   199  	for _, tt := range messageTests {
   200  		t.Run(tt.name, func(t *testing.T) {
   201  			wm := socket.Message{
   202  				Buffers: bytes.SplitAfter([]byte(data), []byte("-")),
   203  				Addr:    tt.dest,
   204  			}
   205  			if err := tt.conn.SendMsg(&wm, 0); err != nil {
   206  				t.Fatal(err)
   207  			}
   208  			b := make([]byte, 32)
   209  			rm := socket.Message{
   210  				Buffers: [][]byte{b[:1], b[1:3], b[3:7], b[7:11], b[11:]},
   211  			}
   212  			if err := cc.RecvMsg(&rm, 0); err != nil {
   213  				t.Fatal(err)
   214  			}
   215  			received := string(b[:rm.N])
   216  			if received != data {
   217  				t.Fatalf("Roundtrip SendMsg/RecvMsg got %q; want %q", received, data)
   218  			}
   219  		})
   220  	}
   221  
   222  	switch runtime.GOOS {
   223  	case "android", "linux":
   224  		messagesTests := []struct {
   225  			name string
   226  			conn *socket.Conn
   227  			dest net.Addr
   228  		}{
   229  			{
   230  				name: "Messages",
   231  				conn: cc,
   232  				dest: c.LocalAddr(),
   233  			},
   234  			{
   235  				name: "Messages-dialed",
   236  				conn: ccDialed,
   237  				dest: nil,
   238  			},
   239  		}
   240  		for _, tt := range messagesTests {
   241  			t.Run(tt.name, func(t *testing.T) {
   242  				wmbs := bytes.SplitAfter([]byte(data), []byte("-"))
   243  				wms := []socket.Message{
   244  					{Buffers: wmbs[:1], Addr: tt.dest},
   245  					{Buffers: wmbs[1:], Addr: tt.dest},
   246  				}
   247  				n, err := tt.conn.SendMsgs(wms, 0)
   248  				if err != nil {
   249  					t.Fatal(err)
   250  				}
   251  				if n != len(wms) {
   252  					t.Fatalf("SendMsgs(%#v) != %d; want %d", wms, n, len(wms))
   253  				}
   254  				rmbs := [][]byte{make([]byte, 32), make([]byte, 32)}
   255  				rms := []socket.Message{
   256  					{Buffers: [][]byte{rmbs[0]}},
   257  					{Buffers: [][]byte{rmbs[1][:1], rmbs[1][1:3], rmbs[1][3:7], rmbs[1][7:11], rmbs[1][11:]}},
   258  				}
   259  				nrecv := 0
   260  				for nrecv < len(rms) {
   261  					n, err := cc.RecvMsgs(rms[nrecv:], 0)
   262  					if err != nil {
   263  						t.Fatal(err)
   264  					}
   265  					nrecv += n
   266  				}
   267  				received0, received1 := string(rmbs[0][:rms[0].N]), string(rmbs[1][:rms[1].N])
   268  				assembled := received0 + received1
   269  				assembledReordered := received1 + received0
   270  				if assembled != data && assembledReordered != data {
   271  					t.Fatalf("Roundtrip SendMsgs/RecvMsgs got %q / %q; want %q", assembled, assembledReordered, data)
   272  				}
   273  			})
   274  		}
   275  		t.Run("Messages-undialed-no-dst", func(t *testing.T) {
   276  			// sending without destination address should fail.
   277  			// This checks that the internally recycled buffers are reset correctly.
   278  			data := []byte("HELLO-R-U-THERE")
   279  			wmbs := bytes.SplitAfter(data, []byte("-"))
   280  			wms := []socket.Message{
   281  				{Buffers: wmbs[:1], Addr: nil},
   282  				{Buffers: wmbs[1:], Addr: nil},
   283  			}
   284  			n, err := cc.SendMsgs(wms, 0)
   285  			if n != 0 && err == nil {
   286  				t.Fatal("expected error, destination address required")
   287  			}
   288  		})
   289  	}
   290  
   291  	// The behavior of transmission for zero byte paylaod depends
   292  	// on each platform implementation. Some may transmit only
   293  	// protocol header and options, other may transmit nothing.
   294  	// We test only that SendMsg and SendMsgs will not crash with
   295  	// empty buffers.
   296  	wm := socket.Message{
   297  		Buffers: [][]byte{{}},
   298  		Addr:    c.LocalAddr(),
   299  	}
   300  	cc.SendMsg(&wm, 0)
   301  	wms := []socket.Message{
   302  		{Buffers: [][]byte{{}}, Addr: c.LocalAddr()},
   303  	}
   304  	cc.SendMsgs(wms, 0)
   305  }
   306  
   307  func BenchmarkUDP(b *testing.B) {
   308  	c, err := nettest.NewLocalPacketListener("udp")
   309  	if err != nil {
   310  		b.Skipf("not supported on %s/%s: %v", runtime.GOOS, runtime.GOARCH, err)
   311  	}
   312  	defer c.Close()
   313  	cc, err := socket.NewConn(c.(net.Conn))
   314  	if err != nil {
   315  		b.Fatal(err)
   316  	}
   317  	data := []byte("HELLO-R-U-THERE")
   318  	wm := socket.Message{
   319  		Buffers: [][]byte{data},
   320  		Addr:    c.LocalAddr(),
   321  	}
   322  	rm := socket.Message{
   323  		Buffers: [][]byte{make([]byte, 128)},
   324  		OOB:     make([]byte, 128),
   325  	}
   326  
   327  	for M := 1; M <= 1<<9; M = M << 1 {
   328  		b.Run(fmt.Sprintf("Iter-%d", M), func(b *testing.B) {
   329  			for i := 0; i < b.N; i++ {
   330  				for j := 0; j < M; j++ {
   331  					if err := cc.SendMsg(&wm, 0); err != nil {
   332  						b.Fatal(err)
   333  					}
   334  					if err := cc.RecvMsg(&rm, 0); err != nil {
   335  						b.Fatal(err)
   336  					}
   337  				}
   338  			}
   339  		})
   340  		switch runtime.GOOS {
   341  		case "android", "linux":
   342  			wms := make([]socket.Message, M)
   343  			for i := range wms {
   344  				wms[i].Buffers = [][]byte{data}
   345  				wms[i].Addr = c.LocalAddr()
   346  			}
   347  			rms := make([]socket.Message, M)
   348  			for i := range rms {
   349  				rms[i].Buffers = [][]byte{make([]byte, 128)}
   350  				rms[i].OOB = make([]byte, 128)
   351  			}
   352  			b.Run(fmt.Sprintf("Batch-%d", M), func(b *testing.B) {
   353  				for i := 0; i < b.N; i++ {
   354  					if _, err := cc.SendMsgs(wms, 0); err != nil {
   355  						b.Fatal(err)
   356  					}
   357  					if _, err := cc.RecvMsgs(rms, 0); err != nil {
   358  						b.Fatal(err)
   359  					}
   360  				}
   361  			})
   362  		}
   363  	}
   364  }
   365  
   366  func TestRace(t *testing.T) {
   367  	tests := []string{
   368  		`
   369  package main
   370  import (
   371  	"log"
   372  	"net"
   373  
   374  	"golang.org/x/net/ipv4"
   375  )
   376  
   377  var g byte
   378  
   379  func main() {
   380  	c, err := net.ListenPacket("udp", "127.0.0.1:0")
   381  	if err != nil {
   382  		log.Fatalf("ListenPacket: %v", err)
   383  	}
   384  	cc := ipv4.NewPacketConn(c)
   385  	sync := make(chan bool)
   386  	src := make([]byte, 100)
   387  	dst := make([]byte, 100)
   388  	go func() {
   389  		if _, err := cc.WriteTo(src, nil, c.LocalAddr()); err != nil {
   390  			log.Fatalf("WriteTo: %v", err)
   391  		}
   392  	}()
   393  	go func() {
   394  		if _, _, _, err := cc.ReadFrom(dst); err != nil {
   395  			log.Fatalf("ReadFrom: %v", err)
   396  		}
   397  		sync <- true
   398  	}()
   399  	g = dst[0]
   400  	<-sync
   401  }
   402  `,
   403  		`
   404  package main
   405  import (
   406  	"log"
   407  	"net"
   408  
   409  	"golang.org/x/net/ipv4"
   410  )
   411  
   412  func main() {
   413  	c, err := net.ListenPacket("udp", "127.0.0.1:0")
   414  	if err != nil {
   415  		log.Fatalf("ListenPacket: %v", err)
   416  	}
   417  	cc := ipv4.NewPacketConn(c)
   418  	sync := make(chan bool)
   419  	src := make([]byte, 100)
   420  	dst := make([]byte, 100)
   421  	go func() {
   422  		if _, err := cc.WriteTo(src, nil, c.LocalAddr()); err != nil {
   423  			log.Fatalf("WriteTo: %v", err)
   424  		}
   425  		sync <- true
   426  	}()
   427  	src[0] = 0
   428  	go func() {
   429  		if _, _, _, err := cc.ReadFrom(dst); err != nil {
   430  			log.Fatalf("ReadFrom: %v", err)
   431  		}
   432  	}()
   433  	<-sync
   434  }
   435  `,
   436  	}
   437  	platforms := map[string]bool{
   438  		"linux/amd64":   true,
   439  		"linux/ppc64le": true,
   440  		"linux/arm64":   true,
   441  	}
   442  	if !platforms[runtime.GOOS+"/"+runtime.GOARCH] {
   443  		t.Skip("skipping test on non-race-enabled host.")
   444  	}
   445  	if runtime.Compiler == "gccgo" {
   446  		t.Skip("skipping race test when built with gccgo")
   447  	}
   448  	dir, err := os.MkdirTemp("", "testrace")
   449  	if err != nil {
   450  		t.Fatalf("failed to create temp directory: %v", err)
   451  	}
   452  	defer os.RemoveAll(dir)
   453  	goBinary := filepath.Join(runtime.GOROOT(), "bin", "go")
   454  	t.Logf("%s version", goBinary)
   455  	got, err := exec.Command(goBinary, "version").CombinedOutput()
   456  	if len(got) > 0 {
   457  		t.Logf("%s", got)
   458  	}
   459  	if err != nil {
   460  		t.Fatalf("go version failed: %v", err)
   461  	}
   462  	for i, test := range tests {
   463  		t.Run(fmt.Sprintf("test %d", i), func(t *testing.T) {
   464  			src := filepath.Join(dir, fmt.Sprintf("test%d.go", i))
   465  			if err := os.WriteFile(src, []byte(test), 0644); err != nil {
   466  				t.Fatalf("failed to write file: %v", err)
   467  			}
   468  			t.Logf("%s run -race %s", goBinary, src)
   469  			got, err := exec.Command(goBinary, "run", "-race", src).CombinedOutput()
   470  			if len(got) > 0 {
   471  				t.Logf("%s", got)
   472  			}
   473  			if strings.Contains(string(got), "-race requires cgo") {
   474  				t.Log("CGO is not enabled so can't use -race")
   475  			} else if !strings.Contains(string(got), "WARNING: DATA RACE") {
   476  				t.Errorf("race not detected for test %d: err:%v", i, err)
   477  			}
   478  		})
   479  	}
   480  }
   481  

View as plain text