1
2
3
4 package winio
5
6 import (
7 "bufio"
8 "bytes"
9 "context"
10 "errors"
11 "io"
12 "net"
13 "sync"
14 "syscall"
15 "testing"
16 "time"
17 "unsafe"
18
19 "golang.org/x/sys/windows"
20 )
21
var (
	// testPipeName is the named-pipe path used by every test in this file.
	testPipeName = `\\.\pipe\winiotestpipe`

	// aLongTimeAgo is an instant in the distant past. The pending-I/O tests
	// pass it to SetReadDeadline/SetWriteDeadline as an already-expired
	// deadline.
	aLongTimeAgo = time.Unix(1, 0)
)
25
26 func TestDialUnknownFailsImmediately(t *testing.T) {
27 _, err := DialPipe(testPipeName, nil)
28 if !errors.Is(err, syscall.ENOENT) {
29 t.Fatalf("expected ENOENT got %v", err)
30 }
31 }
32
33 func TestDialListenerTimesOut(t *testing.T) {
34 l, err := ListenPipe(testPipeName, nil)
35 if err != nil {
36 t.Fatal(err)
37 }
38 defer l.Close()
39 var d = 10 * time.Millisecond
40 _, err = DialPipe(testPipeName, &d)
41 if !errors.Is(err, ErrTimeout) {
42 t.Fatalf("expected ErrTimeout, got %v", err)
43 }
44 }
45
46 func TestDialContextListenerTimesOut(t *testing.T) {
47 l, err := ListenPipe(testPipeName, nil)
48 if err != nil {
49 t.Fatal(err)
50 }
51 defer l.Close()
52 var d = 10 * time.Millisecond
53 ctx, cancel := context.WithTimeout(context.Background(), d)
54 defer cancel()
55 _, err = DialPipeContext(ctx, testPipeName)
56 if !errors.Is(err, context.DeadlineExceeded) {
57 t.Fatalf("expected context.DeadlineExceeded, got %v", err)
58 }
59 }
60
61 func TestDialListenerGetsCancelled(t *testing.T) {
62 ctx, cancel := context.WithCancel(context.Background())
63 l, err := ListenPipe(testPipeName, nil)
64 if err != nil {
65 t.Fatal(err)
66 }
67 ch := make(chan error)
68 defer l.Close()
69 go func(ctx context.Context, ch chan error) {
70 _, err := DialPipeContext(ctx, testPipeName)
71 ch <- err
72 }(ctx, ch)
73 time.Sleep(time.Millisecond * 30)
74 cancel()
75 err = <-ch
76 if !errors.Is(err, context.Canceled) {
77 t.Fatalf("expected context.Canceled, got %v", err)
78 }
79 }
80
81 func TestDialAccessDeniedWithRestrictedSD(t *testing.T) {
82 c := PipeConfig{
83 SecurityDescriptor: "D:P(A;;0x1200FF;;;WD)",
84 }
85 l, err := ListenPipe(testPipeName, &c)
86 if err != nil {
87 t.Fatal(err)
88 }
89 defer l.Close()
90 _, err = DialPipe(testPipeName, nil)
91 if !errors.Is(err, syscall.ERROR_ACCESS_DENIED) {
92 t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err)
93 }
94 }
95
96 func getConnection(cfg *PipeConfig) (client net.Conn, server net.Conn, err error) {
97 l, err := ListenPipe(testPipeName, cfg)
98 if err != nil {
99 return nil, nil, err
100 }
101 defer l.Close()
102
103 type response struct {
104 c net.Conn
105 err error
106 }
107 ch := make(chan response)
108 go func() {
109 c, err := l.Accept()
110 ch <- response{c, err}
111 }()
112
113 c, err := DialPipe(testPipeName, nil)
114 if err != nil {
115 return client, server, err
116 }
117
118 r := <-ch
119 if err = r.err; err != nil {
120 c.Close()
121 return nil, nil, err
122 }
123
124 return c, r.c, nil
125 }
126
127 func TestReadTimeout(t *testing.T) {
128 c, s, err := getConnection(nil)
129 if err != nil {
130 t.Fatal(err)
131 }
132 defer c.Close()
133 defer s.Close()
134
135 _ = c.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
136
137 buf := make([]byte, 10)
138 _, err = c.Read(buf)
139 if !errors.Is(err, ErrTimeout) {
140 t.Fatalf("expected ErrTimeout, got %v", err)
141 }
142 }
143
// server accepts a single connection from l, reads one newline-terminated
// line from it, writes back "got " followed by that line, closes the
// connection and finally sends 1 on ch. Any error along the way panics.
func server(l net.Listener, ch chan int) {
	must := func(err error) {
		if err != nil {
			panic(err)
		}
	}

	conn, err := l.Accept()
	must(err)

	in, out := bufio.NewReader(conn), bufio.NewWriter(conn)

	line, err := in.ReadString('\n')
	must(err)

	_, err = out.WriteString("got " + line)
	must(err)
	must(out.Flush())

	conn.Close()
	ch <- 1
}
165
166 func TestFullListenDialReadWrite(t *testing.T) {
167 l, err := ListenPipe(testPipeName, nil)
168 if err != nil {
169 t.Fatal(err)
170 }
171 defer l.Close()
172
173 ch := make(chan int)
174 go server(l, ch)
175
176 c, err := DialPipe(testPipeName, nil)
177 if err != nil {
178 t.Fatal(err)
179 }
180 defer c.Close()
181
182 rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
183 _, err = rw.WriteString("hello world\n")
184 if err != nil {
185 t.Fatal(err)
186 }
187 err = rw.Flush()
188 if err != nil {
189 t.Fatal(err)
190 }
191
192 s, err := rw.ReadString('\n')
193 if err != nil {
194 t.Fatal(err)
195 }
196 ms := "got hello world\n"
197 if s != ms {
198 t.Errorf("expected '%s', got '%s'", ms, s)
199 }
200
201 <-ch
202 }
203
204 func TestCloseAbortsListen(t *testing.T) {
205 l, err := ListenPipe(testPipeName, nil)
206 if err != nil {
207 t.Fatal(err)
208 }
209
210 ch := make(chan error)
211 go func() {
212 _, err := l.Accept()
213 ch <- err
214 }()
215
216 time.Sleep(30 * time.Millisecond)
217 l.Close()
218
219 err = <-ch
220 if !errors.Is(err, ErrPipeListenerClosed) {
221 t.Fatalf("expected ErrPipeListenerClosed, got %v", err)
222 }
223 }
224
// ensureEOFOnClose closes w and then performs a single Read on r, reporting
// a test error unless that Read returns zero bytes together with io.EOF.
//
// It is an assertion helper: failures are attributed to the calling test's
// line rather than to this function.
func ensureEOFOnClose(t *testing.T, r io.Reader, w io.Closer) {
	t.Helper()

	b := make([]byte, 10)
	w.Close()
	n, err := r.Read(b)
	if n > 0 {
		t.Errorf("unexpected byte count %d", n)
	}
	// Deliberately a direct comparison: Read is expected to return the
	// io.EOF sentinel itself.
	if err != io.EOF {
		t.Errorf("expected EOF: %v", err)
	}
}
236
237 func TestCloseClientEOFServer(t *testing.T) {
238 c, s, err := getConnection(nil)
239 if err != nil {
240 t.Fatal(err)
241 }
242 defer c.Close()
243 defer s.Close()
244 ensureEOFOnClose(t, c, s)
245 }
246
247 func TestCloseServerEOFClient(t *testing.T) {
248 c, s, err := getConnection(nil)
249 if err != nil {
250 t.Fatal(err)
251 }
252 defer c.Close()
253 defer s.Close()
254 ensureEOFOnClose(t, s, c)
255 }
256
257 func TestCloseWriteEOF(t *testing.T) {
258 cfg := &PipeConfig{
259 MessageMode: true,
260 }
261 c, s, err := getConnection(cfg)
262 if err != nil {
263 t.Fatal(err)
264 }
265 defer c.Close()
266 defer s.Close()
267
268 type closeWriter interface {
269 CloseWrite() error
270 }
271
272 err = c.(closeWriter).CloseWrite()
273 if err != nil {
274 t.Fatal(err)
275 }
276
277 b := make([]byte, 10)
278 _, err = s.Read(b)
279 if !errors.Is(err, io.EOF) {
280 t.Fatal(err)
281 }
282 }
283
284 func TestAcceptAfterCloseFails(t *testing.T) {
285 l, err := ListenPipe(testPipeName, nil)
286 if err != nil {
287 t.Fatal(err)
288 }
289 l.Close()
290 _, err = l.Accept()
291 if !errors.Is(err, ErrPipeListenerClosed) {
292 t.Fatalf("expected ErrPipeListenerClosed, got %v", err)
293 }
294 }
295
296 func TestDialTimesOutByDefault(t *testing.T) {
297 l, err := ListenPipe(testPipeName, nil)
298 if err != nil {
299 t.Fatal(err)
300 }
301 defer l.Close()
302 _, err = DialPipe(testPipeName, nil)
303 if !errors.Is(err, ErrTimeout) {
304 t.Fatalf("expected ErrTimeout, got %v", err)
305 }
306 }
307
308 func TestTimeoutPendingRead(t *testing.T) {
309 l, err := ListenPipe(testPipeName, nil)
310 if err != nil {
311 t.Fatal(err)
312 }
313 defer l.Close()
314
315 serverDone := make(chan struct{})
316
317 go func() {
318 s, err := l.Accept()
319 if err != nil {
320 t.Error(err)
321 return
322 }
323 time.Sleep(1 * time.Second)
324 s.Close()
325 close(serverDone)
326 }()
327
328 client, err := DialPipe(testPipeName, nil)
329 if err != nil {
330 t.Fatal(err)
331 }
332 defer client.Close()
333
334 clientErr := make(chan error)
335 go func() {
336 buf := make([]byte, 10)
337 _, err = client.Read(buf)
338 clientErr <- err
339 }()
340
341 time.Sleep(100 * time.Millisecond)
342 _ = client.SetReadDeadline(aLongTimeAgo)
343
344 select {
345 case err = <-clientErr:
346 if !errors.Is(err, ErrTimeout) {
347 t.Fatalf("expected ErrTimeout, got %v", err)
348 }
349 case <-time.After(100 * time.Millisecond):
350 t.Fatalf("timed out while waiting for read to cancel")
351 <-clientErr
352 }
353 <-serverDone
354 }
355
356 func TestTimeoutPendingWrite(t *testing.T) {
357 l, err := ListenPipe(testPipeName, nil)
358 if err != nil {
359 t.Fatal(err)
360 }
361 defer l.Close()
362
363 serverDone := make(chan struct{})
364
365 go func() {
366 s, err := l.Accept()
367 if err != nil {
368 t.Error(err)
369 return
370 }
371 time.Sleep(1 * time.Second)
372 s.Close()
373 close(serverDone)
374 }()
375
376 client, err := DialPipe(testPipeName, nil)
377 if err != nil {
378 t.Fatal(err)
379 }
380 defer client.Close()
381
382 clientErr := make(chan error)
383 go func() {
384 _, err = client.Write([]byte("this should timeout"))
385 clientErr <- err
386 }()
387
388 time.Sleep(100 * time.Millisecond)
389 _ = client.SetWriteDeadline(aLongTimeAgo)
390
391 select {
392 case err = <-clientErr:
393 if !errors.Is(err, ErrTimeout) {
394 t.Fatalf("expected ErrTimeout, got %v", err)
395 }
396 case <-time.After(100 * time.Millisecond):
397 t.Fatalf("timed out while waiting for write to cancel")
398 <-clientErr
399 }
400 <-serverDone
401 }
402
// CloseWriter describes a value that exposes a CloseWrite method. The tests
// in this file type-assert pipe connections to it in order to call
// CloseWrite on them.
type CloseWriter interface{ CloseWrite() error }
406
407 func TestEchoWithMessaging(t *testing.T) {
408 c := PipeConfig{
409 MessageMode: true,
410 InputBufferSize: 65536,
411 OutputBufferSize: 65536,
412 }
413 l, err := ListenPipe(testPipeName, &c)
414 if err != nil {
415 t.Fatal(err)
416 }
417 defer l.Close()
418
419 listenerDone := make(chan bool)
420 clientDone := make(chan bool)
421 go func() {
422
423 conn, e := l.Accept()
424 if e != nil {
425 t.Error(err)
426 return
427 }
428 defer conn.Close()
429
430 time.Sleep(500 * time.Millisecond)
431 _, _ = io.Copy(conn, conn)
432 _ = conn.(CloseWriter).CloseWrite()
433 close(listenerDone)
434 }()
435 timeout := 1 * time.Second
436 client, err := DialPipe(testPipeName, &timeout)
437 if err != nil {
438 t.Fatal(err)
439 }
440 defer client.Close()
441
442 go func() {
443
444 bytes := make([]byte, 2)
445 n, e := client.Read(bytes)
446 if e != nil {
447 t.Error(err)
448 return
449 }
450 if n != 2 {
451 t.Errorf("expected 2 bytes, got %v", n)
452 return
453 }
454 close(clientDone)
455 }()
456
457 payload := make([]byte, 2)
458 payload[0] = 0
459 payload[1] = 1
460
461 n, err := client.Write(payload)
462 if err != nil {
463 t.Fatal(err)
464 }
465 if n != 2 {
466 t.Fatalf("expected 2 bytes, got %v", n)
467 }
468 _ = client.(CloseWriter).CloseWrite()
469 <-listenerDone
470 <-clientDone
471 }
472
473 func TestConnectRace(t *testing.T) {
474 l, err := ListenPipe(testPipeName, nil)
475 if err != nil {
476 t.Fatal(err)
477 }
478 defer l.Close()
479 go func() {
480 for {
481 s, err := l.Accept()
482 if errors.Is(err, ErrPipeListenerClosed) {
483 return
484 }
485
486 if err != nil {
487 t.Error(err)
488 return
489 }
490 s.Close()
491 }
492 }()
493
494 for i := 0; i < 1000; i++ {
495 c, err := DialPipe(testPipeName, nil)
496 if err != nil {
497 t.Fatal(err)
498 }
499 c.Close()
500 }
501 }
502
503 func TestMessageReadMode(t *testing.T) {
504 var wg sync.WaitGroup
505 defer wg.Wait()
506
507 l, err := ListenPipe(testPipeName, &PipeConfig{MessageMode: true})
508 if err != nil {
509 t.Fatal(err)
510 }
511 defer l.Close()
512
513 msg := ([]byte)("hello world")
514
515 wg.Add(1)
516 go func() {
517 defer wg.Done()
518 s, err := l.Accept()
519 if err != nil {
520 t.Error(err)
521 return
522 }
523 _, err = s.Write(msg)
524 if err != nil {
525 t.Error(err)
526 return
527 }
528 s.Close()
529 }()
530
531 c, err := DialPipe(testPipeName, nil)
532 if err != nil {
533 t.Fatal(err)
534 }
535 defer c.Close()
536
537 setNamedPipeHandleState := syscall.NewLazyDLL("kernel32.dll").NewProc("SetNamedPipeHandleState")
538
539 p := c.(*win32MessageBytePipe)
540 mode := uint32(windows.PIPE_READMODE_MESSAGE)
541 if s, _, err := setNamedPipeHandleState.Call(uintptr(p.handle), uintptr(unsafe.Pointer(&mode)), 0, 0); s == 0 {
542 t.Fatal(err)
543 }
544
545 ch := make([]byte, 1)
546 var vmsg []byte
547 for {
548 n, err := c.Read(ch)
549 if err == io.EOF {
550 break
551 }
552 if err != nil {
553 t.Fatal(err)
554 }
555 if n != 1 {
556 t.Fatal("expected 1: ", n)
557 }
558 vmsg = append(vmsg, ch[0])
559 }
560 if !bytes.Equal(msg, vmsg) {
561 t.Fatalf("expected %s: %s", msg, vmsg)
562 }
563 }
564
565 func TestListenConnectRace(t *testing.T) {
566 for i := 0; i < 50 && !t.Failed(); i++ {
567 var wg sync.WaitGroup
568 wg.Add(1)
569 go func() {
570 c, err := DialPipe(testPipeName, nil)
571 if err == nil {
572 c.Close()
573 }
574 wg.Done()
575 }()
576 s, err := ListenPipe(testPipeName, nil)
577 if err != nil {
578 t.Error(i, err)
579 } else {
580 s.Close()
581 }
582 wg.Wait()
583 }
584 }
585
View as plain text