...

Source file src/github.com/Microsoft/go-winio/pipe_test.go

Documentation: github.com/Microsoft/go-winio

     1  //go:build windows
     2  // +build windows
     3  
     4  package winio
     5  
     6  import (
     7  	"bufio"
     8  	"bytes"
     9  	"context"
    10  	"errors"
    11  	"io"
    12  	"net"
    13  	"sync"
    14  	"syscall"
    15  	"testing"
    16  	"time"
    17  	"unsafe"
    18  
    19  	"golang.org/x/sys/windows"
    20  )
    21  
    22  var testPipeName = `\\.\pipe\winiotestpipe`
    23  
    24  var aLongTimeAgo = time.Unix(1, 0)
    25  
    26  func TestDialUnknownFailsImmediately(t *testing.T) {
    27  	_, err := DialPipe(testPipeName, nil)
    28  	if !errors.Is(err, syscall.ENOENT) {
    29  		t.Fatalf("expected ENOENT got %v", err)
    30  	}
    31  }
    32  
    33  func TestDialListenerTimesOut(t *testing.T) {
    34  	l, err := ListenPipe(testPipeName, nil)
    35  	if err != nil {
    36  		t.Fatal(err)
    37  	}
    38  	defer l.Close()
    39  	var d = 10 * time.Millisecond
    40  	_, err = DialPipe(testPipeName, &d)
    41  	if !errors.Is(err, ErrTimeout) {
    42  		t.Fatalf("expected ErrTimeout, got %v", err)
    43  	}
    44  }
    45  
    46  func TestDialContextListenerTimesOut(t *testing.T) {
    47  	l, err := ListenPipe(testPipeName, nil)
    48  	if err != nil {
    49  		t.Fatal(err)
    50  	}
    51  	defer l.Close()
    52  	var d = 10 * time.Millisecond
    53  	ctx, cancel := context.WithTimeout(context.Background(), d)
    54  	defer cancel()
    55  	_, err = DialPipeContext(ctx, testPipeName)
    56  	if !errors.Is(err, context.DeadlineExceeded) {
    57  		t.Fatalf("expected context.DeadlineExceeded, got %v", err)
    58  	}
    59  }
    60  
    61  func TestDialListenerGetsCancelled(t *testing.T) {
    62  	ctx, cancel := context.WithCancel(context.Background())
    63  	l, err := ListenPipe(testPipeName, nil)
    64  	if err != nil {
    65  		t.Fatal(err)
    66  	}
    67  	ch := make(chan error)
    68  	defer l.Close()
    69  	go func(ctx context.Context, ch chan error) {
    70  		_, err := DialPipeContext(ctx, testPipeName)
    71  		ch <- err
    72  	}(ctx, ch)
    73  	time.Sleep(time.Millisecond * 30)
    74  	cancel()
    75  	err = <-ch
    76  	if !errors.Is(err, context.Canceled) {
    77  		t.Fatalf("expected context.Canceled, got %v", err)
    78  	}
    79  }
    80  
    81  func TestDialAccessDeniedWithRestrictedSD(t *testing.T) {
    82  	c := PipeConfig{
    83  		SecurityDescriptor: "D:P(A;;0x1200FF;;;WD)",
    84  	}
    85  	l, err := ListenPipe(testPipeName, &c)
    86  	if err != nil {
    87  		t.Fatal(err)
    88  	}
    89  	defer l.Close()
    90  	_, err = DialPipe(testPipeName, nil)
    91  	if !errors.Is(err, syscall.ERROR_ACCESS_DENIED) {
    92  		t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err)
    93  	}
    94  }
    95  
    96  func getConnection(cfg *PipeConfig) (client net.Conn, server net.Conn, err error) {
    97  	l, err := ListenPipe(testPipeName, cfg)
    98  	if err != nil {
    99  		return nil, nil, err
   100  	}
   101  	defer l.Close()
   102  
   103  	type response struct {
   104  		c   net.Conn
   105  		err error
   106  	}
   107  	ch := make(chan response)
   108  	go func() {
   109  		c, err := l.Accept()
   110  		ch <- response{c, err}
   111  	}()
   112  
   113  	c, err := DialPipe(testPipeName, nil)
   114  	if err != nil {
   115  		return client, server, err
   116  	}
   117  
   118  	r := <-ch
   119  	if err = r.err; err != nil {
   120  		c.Close()
   121  		return nil, nil, err
   122  	}
   123  
   124  	return c, r.c, nil
   125  }
   126  
   127  func TestReadTimeout(t *testing.T) {
   128  	c, s, err := getConnection(nil)
   129  	if err != nil {
   130  		t.Fatal(err)
   131  	}
   132  	defer c.Close()
   133  	defer s.Close()
   134  
   135  	_ = c.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
   136  
   137  	buf := make([]byte, 10)
   138  	_, err = c.Read(buf)
   139  	if !errors.Is(err, ErrTimeout) {
   140  		t.Fatalf("expected ErrTimeout, got %v", err)
   141  	}
   142  }
   143  
   144  func server(l net.Listener, ch chan int) {
   145  	c, err := l.Accept()
   146  	if err != nil {
   147  		panic(err)
   148  	}
   149  	rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
   150  	s, err := rw.ReadString('\n')
   151  	if err != nil {
   152  		panic(err)
   153  	}
   154  	_, err = rw.WriteString("got " + s)
   155  	if err != nil {
   156  		panic(err)
   157  	}
   158  	err = rw.Flush()
   159  	if err != nil {
   160  		panic(err)
   161  	}
   162  	c.Close()
   163  	ch <- 1
   164  }
   165  
   166  func TestFullListenDialReadWrite(t *testing.T) {
   167  	l, err := ListenPipe(testPipeName, nil)
   168  	if err != nil {
   169  		t.Fatal(err)
   170  	}
   171  	defer l.Close()
   172  
   173  	ch := make(chan int)
   174  	go server(l, ch)
   175  
   176  	c, err := DialPipe(testPipeName, nil)
   177  	if err != nil {
   178  		t.Fatal(err)
   179  	}
   180  	defer c.Close()
   181  
   182  	rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
   183  	_, err = rw.WriteString("hello world\n")
   184  	if err != nil {
   185  		t.Fatal(err)
   186  	}
   187  	err = rw.Flush()
   188  	if err != nil {
   189  		t.Fatal(err)
   190  	}
   191  
   192  	s, err := rw.ReadString('\n')
   193  	if err != nil {
   194  		t.Fatal(err)
   195  	}
   196  	ms := "got hello world\n"
   197  	if s != ms {
   198  		t.Errorf("expected '%s', got '%s'", ms, s)
   199  	}
   200  
   201  	<-ch
   202  }
   203  
   204  func TestCloseAbortsListen(t *testing.T) {
   205  	l, err := ListenPipe(testPipeName, nil)
   206  	if err != nil {
   207  		t.Fatal(err)
   208  	}
   209  
   210  	ch := make(chan error)
   211  	go func() {
   212  		_, err := l.Accept()
   213  		ch <- err
   214  	}()
   215  
   216  	time.Sleep(30 * time.Millisecond)
   217  	l.Close()
   218  
   219  	err = <-ch
   220  	if !errors.Is(err, ErrPipeListenerClosed) {
   221  		t.Fatalf("expected ErrPipeListenerClosed, got %v", err)
   222  	}
   223  }
   224  
   225  func ensureEOFOnClose(t *testing.T, r io.Reader, w io.Closer) {
   226  	b := make([]byte, 10)
   227  	w.Close()
   228  	n, err := r.Read(b)
   229  	if n > 0 {
   230  		t.Errorf("unexpected byte count %d", n)
   231  	}
   232  	if err != io.EOF {
   233  		t.Errorf("expected EOF: %v", err)
   234  	}
   235  }
   236  
   237  func TestCloseClientEOFServer(t *testing.T) {
   238  	c, s, err := getConnection(nil)
   239  	if err != nil {
   240  		t.Fatal(err)
   241  	}
   242  	defer c.Close()
   243  	defer s.Close()
   244  	ensureEOFOnClose(t, c, s)
   245  }
   246  
   247  func TestCloseServerEOFClient(t *testing.T) {
   248  	c, s, err := getConnection(nil)
   249  	if err != nil {
   250  		t.Fatal(err)
   251  	}
   252  	defer c.Close()
   253  	defer s.Close()
   254  	ensureEOFOnClose(t, s, c)
   255  }
   256  
   257  func TestCloseWriteEOF(t *testing.T) {
   258  	cfg := &PipeConfig{
   259  		MessageMode: true,
   260  	}
   261  	c, s, err := getConnection(cfg)
   262  	if err != nil {
   263  		t.Fatal(err)
   264  	}
   265  	defer c.Close()
   266  	defer s.Close()
   267  
   268  	type closeWriter interface {
   269  		CloseWrite() error
   270  	}
   271  
   272  	err = c.(closeWriter).CloseWrite()
   273  	if err != nil {
   274  		t.Fatal(err)
   275  	}
   276  
   277  	b := make([]byte, 10)
   278  	_, err = s.Read(b)
   279  	if !errors.Is(err, io.EOF) {
   280  		t.Fatal(err)
   281  	}
   282  }
   283  
   284  func TestAcceptAfterCloseFails(t *testing.T) {
   285  	l, err := ListenPipe(testPipeName, nil)
   286  	if err != nil {
   287  		t.Fatal(err)
   288  	}
   289  	l.Close()
   290  	_, err = l.Accept()
   291  	if !errors.Is(err, ErrPipeListenerClosed) {
   292  		t.Fatalf("expected ErrPipeListenerClosed, got %v", err)
   293  	}
   294  }
   295  
   296  func TestDialTimesOutByDefault(t *testing.T) {
   297  	l, err := ListenPipe(testPipeName, nil)
   298  	if err != nil {
   299  		t.Fatal(err)
   300  	}
   301  	defer l.Close()
   302  	_, err = DialPipe(testPipeName, nil)
   303  	if !errors.Is(err, ErrTimeout) {
   304  		t.Fatalf("expected ErrTimeout, got %v", err)
   305  	}
   306  }
   307  
   308  func TestTimeoutPendingRead(t *testing.T) {
   309  	l, err := ListenPipe(testPipeName, nil)
   310  	if err != nil {
   311  		t.Fatal(err)
   312  	}
   313  	defer l.Close()
   314  
   315  	serverDone := make(chan struct{})
   316  
   317  	go func() {
   318  		s, err := l.Accept()
   319  		if err != nil {
   320  			t.Error(err)
   321  			return
   322  		}
   323  		time.Sleep(1 * time.Second)
   324  		s.Close()
   325  		close(serverDone)
   326  	}()
   327  
   328  	client, err := DialPipe(testPipeName, nil)
   329  	if err != nil {
   330  		t.Fatal(err)
   331  	}
   332  	defer client.Close()
   333  
   334  	clientErr := make(chan error)
   335  	go func() {
   336  		buf := make([]byte, 10)
   337  		_, err = client.Read(buf)
   338  		clientErr <- err
   339  	}()
   340  
   341  	time.Sleep(100 * time.Millisecond) // make *sure* the pipe is reading before we set the deadline
   342  	_ = client.SetReadDeadline(aLongTimeAgo)
   343  
   344  	select {
   345  	case err = <-clientErr:
   346  		if !errors.Is(err, ErrTimeout) {
   347  			t.Fatalf("expected ErrTimeout, got %v", err)
   348  		}
   349  	case <-time.After(100 * time.Millisecond):
   350  		t.Fatalf("timed out while waiting for read to cancel")
   351  		<-clientErr
   352  	}
   353  	<-serverDone
   354  }
   355  
   356  func TestTimeoutPendingWrite(t *testing.T) {
   357  	l, err := ListenPipe(testPipeName, nil)
   358  	if err != nil {
   359  		t.Fatal(err)
   360  	}
   361  	defer l.Close()
   362  
   363  	serverDone := make(chan struct{})
   364  
   365  	go func() {
   366  		s, err := l.Accept()
   367  		if err != nil {
   368  			t.Error(err)
   369  			return
   370  		}
   371  		time.Sleep(1 * time.Second)
   372  		s.Close()
   373  		close(serverDone)
   374  	}()
   375  
   376  	client, err := DialPipe(testPipeName, nil)
   377  	if err != nil {
   378  		t.Fatal(err)
   379  	}
   380  	defer client.Close()
   381  
   382  	clientErr := make(chan error)
   383  	go func() {
   384  		_, err = client.Write([]byte("this should timeout"))
   385  		clientErr <- err
   386  	}()
   387  
   388  	time.Sleep(100 * time.Millisecond) // make *sure* the pipe is writing before we set the deadline
   389  	_ = client.SetWriteDeadline(aLongTimeAgo)
   390  
   391  	select {
   392  	case err = <-clientErr:
   393  		if !errors.Is(err, ErrTimeout) {
   394  			t.Fatalf("expected ErrTimeout, got %v", err)
   395  		}
   396  	case <-time.After(100 * time.Millisecond):
   397  		t.Fatalf("timed out while waiting for write to cancel")
   398  		<-clientErr
   399  	}
   400  	<-serverDone
   401  }
   402  
   403  type CloseWriter interface {
   404  	CloseWrite() error
   405  }
   406  
   407  func TestEchoWithMessaging(t *testing.T) {
   408  	c := PipeConfig{
   409  		MessageMode:      true,  // Use message mode so that CloseWrite() is supported
   410  		InputBufferSize:  65536, // Use 64KB buffers to improve performance
   411  		OutputBufferSize: 65536,
   412  	}
   413  	l, err := ListenPipe(testPipeName, &c)
   414  	if err != nil {
   415  		t.Fatal(err)
   416  	}
   417  	defer l.Close()
   418  
   419  	listenerDone := make(chan bool)
   420  	clientDone := make(chan bool)
   421  	go func() {
   422  		// server echo
   423  		conn, e := l.Accept()
   424  		if e != nil {
   425  			t.Error(err)
   426  			return
   427  		}
   428  		defer conn.Close()
   429  
   430  		time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent
   431  		_, _ = io.Copy(conn, conn)
   432  		_ = conn.(CloseWriter).CloseWrite()
   433  		close(listenerDone)
   434  	}()
   435  	timeout := 1 * time.Second
   436  	client, err := DialPipe(testPipeName, &timeout)
   437  	if err != nil {
   438  		t.Fatal(err)
   439  	}
   440  	defer client.Close()
   441  
   442  	go func() {
   443  		// client read back
   444  		bytes := make([]byte, 2)
   445  		n, e := client.Read(bytes)
   446  		if e != nil {
   447  			t.Error(err)
   448  			return
   449  		}
   450  		if n != 2 {
   451  			t.Errorf("expected 2 bytes, got %v", n)
   452  			return
   453  		}
   454  		close(clientDone)
   455  	}()
   456  
   457  	payload := make([]byte, 2)
   458  	payload[0] = 0
   459  	payload[1] = 1
   460  
   461  	n, err := client.Write(payload)
   462  	if err != nil {
   463  		t.Fatal(err)
   464  	}
   465  	if n != 2 {
   466  		t.Fatalf("expected 2 bytes, got %v", n)
   467  	}
   468  	_ = client.(CloseWriter).CloseWrite()
   469  	<-listenerDone
   470  	<-clientDone
   471  }
   472  
   473  func TestConnectRace(t *testing.T) {
   474  	l, err := ListenPipe(testPipeName, nil)
   475  	if err != nil {
   476  		t.Fatal(err)
   477  	}
   478  	defer l.Close()
   479  	go func() {
   480  		for {
   481  			s, err := l.Accept()
   482  			if errors.Is(err, ErrPipeListenerClosed) {
   483  				return
   484  			}
   485  
   486  			if err != nil {
   487  				t.Error(err)
   488  				return
   489  			}
   490  			s.Close()
   491  		}
   492  	}()
   493  
   494  	for i := 0; i < 1000; i++ {
   495  		c, err := DialPipe(testPipeName, nil)
   496  		if err != nil {
   497  			t.Fatal(err)
   498  		}
   499  		c.Close()
   500  	}
   501  }
   502  
   503  func TestMessageReadMode(t *testing.T) {
   504  	var wg sync.WaitGroup
   505  	defer wg.Wait()
   506  
   507  	l, err := ListenPipe(testPipeName, &PipeConfig{MessageMode: true})
   508  	if err != nil {
   509  		t.Fatal(err)
   510  	}
   511  	defer l.Close()
   512  
   513  	msg := ([]byte)("hello world")
   514  
   515  	wg.Add(1)
   516  	go func() {
   517  		defer wg.Done()
   518  		s, err := l.Accept()
   519  		if err != nil {
   520  			t.Error(err)
   521  			return
   522  		}
   523  		_, err = s.Write(msg)
   524  		if err != nil {
   525  			t.Error(err)
   526  			return
   527  		}
   528  		s.Close()
   529  	}()
   530  
   531  	c, err := DialPipe(testPipeName, nil)
   532  	if err != nil {
   533  		t.Fatal(err)
   534  	}
   535  	defer c.Close()
   536  
   537  	setNamedPipeHandleState := syscall.NewLazyDLL("kernel32.dll").NewProc("SetNamedPipeHandleState")
   538  
   539  	p := c.(*win32MessageBytePipe)
   540  	mode := uint32(windows.PIPE_READMODE_MESSAGE)
   541  	if s, _, err := setNamedPipeHandleState.Call(uintptr(p.handle), uintptr(unsafe.Pointer(&mode)), 0, 0); s == 0 {
   542  		t.Fatal(err)
   543  	}
   544  
   545  	ch := make([]byte, 1)
   546  	var vmsg []byte
   547  	for {
   548  		n, err := c.Read(ch)
   549  		if err == io.EOF { //nolint:errorlint
   550  			break
   551  		}
   552  		if err != nil {
   553  			t.Fatal(err)
   554  		}
   555  		if n != 1 {
   556  			t.Fatal("expected 1: ", n)
   557  		}
   558  		vmsg = append(vmsg, ch[0])
   559  	}
   560  	if !bytes.Equal(msg, vmsg) {
   561  		t.Fatalf("expected %s: %s", msg, vmsg)
   562  	}
   563  }
   564  
   565  func TestListenConnectRace(t *testing.T) {
   566  	for i := 0; i < 50 && !t.Failed(); i++ {
   567  		var wg sync.WaitGroup
   568  		wg.Add(1)
   569  		go func() {
   570  			c, err := DialPipe(testPipeName, nil)
   571  			if err == nil {
   572  				c.Close()
   573  			}
   574  			wg.Done()
   575  		}()
   576  		s, err := ListenPipe(testPipeName, nil)
   577  		if err != nil {
   578  			t.Error(i, err)
   579  		} else {
   580  			s.Close()
   581  		}
   582  		wg.Wait()
   583  	}
   584  }
   585  

View as plain text