1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package cmux
16
17 import (
18 "bytes"
19 "crypto/rand"
20 "crypto/tls"
21 "errors"
22 "fmt"
23 "go/build"
24 "io"
25 "io/ioutil"
26 "log"
27 "net"
28 "net/http"
29 "net/rpc"
30 "os"
31 "os/exec"
32 "runtime"
33 "sort"
34 "strings"
35 "sync"
36 "sync/atomic"
37 "testing"
38 "time"
39
40 "golang.org/x/net/http2"
41 "golang.org/x/net/http2/hpack"
42 )
43
// testHTTP1Resp is the body every test HTTP server writes and every test
// HTTP client expects to read back.
const testHTTP1Resp = "http1"

// rpcVal is the argument passed to TestRPCRcvr.Test; the reply must echo it.
const rpcVal = 1234
48
49 func safeServe(errCh chan<- error, muxl CMux) {
50 if err := muxl.Serve(); !strings.Contains(err.Error(), "use of closed") {
51 errCh <- err
52 }
53 }
54
// safeDial connects a net/rpc client to addr and stops the test if the dial
// fails. Alongside the client it returns a cleanup func that closes the
// client and stops the test if closing reports an error.
func safeDial(t *testing.T, addr net.Addr) (*rpc.Client, func()) {
	network, address := addr.Network(), addr.String()
	client, dialErr := rpc.Dial(network, address)
	if dialErr != nil {
		t.Fatal(dialErr)
	}
	closeClient := func() {
		if closeErr := client.Close(); closeErr != nil {
			t.Fatal(closeErr)
		}
	}
	return client, closeClient
}
66
// chanListener is a net.Listener whose Accept hands out the connections a
// test pushes into connCh. The embedded net.Listener is never set, so only
// Accept may be called; any other Listener method (Close, Addr) would panic
// on the nil interface.
type chanListener struct {
	net.Listener
	connCh chan net.Conn
}

// newChanListener returns a chanListener whose connCh buffers one pending
// connection, so a test can queue a conn before anything calls Accept.
func newChanListener() *chanListener {
	l := &chanListener{}
	l.connCh = make(chan net.Conn, 1)
	return l
}

// Accept blocks until a connection is pushed into connCh and returns it.
// After connCh has been closed and drained, it returns an error whose text
// imitates the one a closed network listener reports.
func (l *chanListener) Accept() (net.Conn, error) {
	c, ok := <-l.connCh
	if !ok {
		return nil, errors.New("use of closed network connection")
	}
	return c, nil
}
82
// testListener opens a TCP listener on an ephemeral loopback port and stops
// the test if that fails. The returned closer is idempotent: only its first
// call closes the listener, and a close failure stops the test.
func testListener(t *testing.T) (net.Listener, func()) {
	lis, listenErr := net.Listen("tcp", "127.0.0.1:0")
	if listenErr != nil {
		t.Fatal(listenErr)
	}
	var closeOnce sync.Once
	closer := func() {
		closeOnce.Do(func() {
			if err := lis.Close(); err != nil {
				t.Fatal(err)
			}
		})
	}
	return lis, closer
}
97
98 type testHTTP1Handler struct{}
99
100 func (h *testHTTP1Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
101 fmt.Fprintf(w, testHTTP1Resp)
102 }
103
104 func runTestHTTPServer(errCh chan<- error, l net.Listener) {
105 var mu sync.Mutex
106 conns := make(map[net.Conn]struct{})
107
108 defer func() {
109 mu.Lock()
110 for c := range conns {
111 if err := c.Close(); err != nil {
112 errCh <- err
113 }
114 }
115 mu.Unlock()
116 }()
117
118 s := &http.Server{
119 Handler: &testHTTP1Handler{},
120 ConnState: func(c net.Conn, state http.ConnState) {
121 mu.Lock()
122 switch state {
123 case http.StateNew:
124 conns[c] = struct{}{}
125 case http.StateClosed:
126 delete(conns, c)
127 }
128 mu.Unlock()
129 },
130 }
131 if err := s.Serve(l); err != ErrListenerClosed && err != ErrServerClosed {
132 errCh <- err
133 }
134 }
135
// generateTLSCert creates a self-signed certificate for the wildcard host
// "*" by running GOROOT's crypto/tls/generate_cert.go. It presumably writes
// cert.pem and key.pem into the working directory — the files
// runTestTLSServer loads and cleanupTLSCert removes.
//
// On failure the test stops, and the message includes the tool's combined
// stdout/stderr. Cmd.Run alone would discard that output and leave only an
// exit status.
func generateTLSCert(t *testing.T) {
	t.Helper()
	script := build.Default.GOROOT + "/src/crypto/tls/generate_cert.go"
	out, err := exec.Command("go", "run", script, "--host", "*").CombinedOutput()
	if err != nil {
		t.Fatalf("generating TLS certificate with %s: %v\n%s", script, err, out)
	}
}
142
// cleanupTLSCert deletes the cert.pem and key.pem files produced for the TLS
// test. A failed removal is reported with t.Error, so both removals are
// always attempted and the test keeps running.
func cleanupTLSCert(t *testing.T) {
	for _, name := range []string{"cert.pem", "key.pem"} {
		if err := os.Remove(name); err != nil {
			t.Error(err)
		}
	}
}
153
154 func runTestTLSServer(errCh chan<- error, l net.Listener) {
155 certificate, err := tls.LoadX509KeyPair("cert.pem", "key.pem")
156 if err != nil {
157 errCh <- err
158 log.Printf("1")
159 return
160 }
161
162 config := &tls.Config{
163 Certificates: []tls.Certificate{certificate},
164 Rand: rand.Reader,
165 }
166
167 tlsl := tls.NewListener(l, config)
168 runTestHTTPServer(errCh, tlsl)
169 }
170
171 func runTestHTTP1Client(t *testing.T, addr net.Addr) {
172 runTestHTTPClient(t, "http", addr)
173 }
174
175 func runTestTLSClient(t *testing.T, addr net.Addr) {
176 runTestHTTPClient(t, "https", addr)
177 }
178
179 func runTestHTTPClient(t *testing.T, proto string, addr net.Addr) {
180 client := http.Client{
181 Timeout: 5 * time.Second,
182 Transport: &http.Transport{
183 TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
184 },
185 }
186 r, err := client.Get(proto + "://" + addr.String())
187 if err != nil {
188 t.Fatal(err)
189 }
190
191 defer func() {
192 if err = r.Body.Close(); err != nil {
193 t.Fatal(err)
194 }
195 }()
196
197 b, err := ioutil.ReadAll(r.Body)
198 if err != nil {
199 t.Fatal(err)
200 }
201 if string(b) != testHTTP1Resp {
202 t.Fatalf("invalid response: want=%s got=%s", testHTTP1Resp, b)
203 }
204 }
205
// TestRPCRcvr is the net/rpc service registered by runTestRPCServer.
type TestRPCRcvr struct{}

// Test echoes its argument into reply and never fails.
func (TestRPCRcvr) Test(in int, reply *int) error {
	*reply = in
	return nil
}
212
213 func runTestRPCServer(errCh chan<- error, l net.Listener) {
214 s := rpc.NewServer()
215 if err := s.Register(TestRPCRcvr{}); err != nil {
216 errCh <- err
217 }
218 for {
219 c, err := l.Accept()
220 if err != nil {
221 if err != ErrListenerClosed && err != ErrServerClosed {
222 errCh <- err
223 }
224 return
225 }
226 go s.ServeConn(c)
227 }
228 }
229
230 func runTestRPCClient(t *testing.T, addr net.Addr) {
231 c, cleanup := safeDial(t, addr)
232 defer cleanup()
233
234 var num int
235 if err := c.Call("TestRPCRcvr.Test", rpcVal, &num); err != nil {
236 t.Fatal(err)
237 }
238
239 if num != rpcVal {
240 t.Errorf("wrong rpc response: want=%d got=%v", rpcVal, num)
241 }
242 }
243
// Outcomes reported by TestTimeout's accept goroutines. Each says whether
// the HTTP/1 or the catch-all listener had its Accept fail (…Close) or
// received a connection (…Request).
const (
	handleHTTP1Close = iota + 1
	handleHTTP1Request
	handleAnyClose
	handleAnyRequest
)
250
251 func TestTimeout(t *testing.T) {
252 defer leakCheck(t)()
253 lis, Close := testListener(t)
254 defer Close()
255 result := make(chan int, 5)
256 testDuration := time.Millisecond * 500
257 m := New(lis)
258 m.SetReadTimeout(testDuration)
259 http1 := m.Match(HTTP1Fast())
260 any := m.Match(Any())
261 go func() {
262 _ = m.Serve()
263 }()
264 go func() {
265 con, err := http1.Accept()
266 if err != nil {
267 result <- handleHTTP1Close
268 } else {
269 _, _ = con.Write([]byte("http1"))
270 _ = con.Close()
271 result <- handleHTTP1Request
272 }
273 }()
274 go func() {
275 con, err := any.Accept()
276 if err != nil {
277 result <- handleAnyClose
278 } else {
279 _, _ = con.Write([]byte("any"))
280 _ = con.Close()
281 result <- handleAnyRequest
282 }
283 }()
284 time.Sleep(testDuration)
285 client, err := net.Dial("tcp", lis.Addr().String())
286 if err != nil {
287 log.Fatal("testTimeout client failed: ", err)
288 }
289 defer func() {
290 _ = client.Close()
291 }()
292 time.Sleep(testDuration / 2)
293 if len(result) != 0 {
294 log.Print("tcp ")
295 t.Fatal("testTimeout failed: accepted to fast: ", len(result))
296 }
297 _ = client.SetReadDeadline(time.Now().Add(testDuration * 3))
298 buffer := make([]byte, 10)
299 rl, err := client.Read(buffer)
300 if err != nil {
301 t.Fatal("testTimeout failed: client error: ", err, rl)
302 }
303 Close()
304 if rl != 3 {
305 log.Print("testTimeout failed: response from wrong sevice ", rl)
306 }
307 if string(buffer[0:3]) != "any" {
308 log.Print("testTimeout failed: response from wrong sevice ")
309 }
310 time.Sleep(testDuration * 2)
311 if len(result) != 2 {
312 t.Fatal("testTimeout failed: accepted to less: ", len(result))
313 }
314 if a := <-result; a != handleAnyRequest {
315 t.Fatal("testTimeout failed: any rule did not match")
316 }
317 if a := <-result; a != handleHTTP1Close {
318 t.Fatal("testTimeout failed: no close an http rule")
319 }
320 }
321
322 func TestRead(t *testing.T) {
323 defer leakCheck(t)()
324 errCh := make(chan error)
325 defer func() {
326 select {
327 case err := <-errCh:
328 t.Fatal(err)
329 default:
330 }
331 }()
332 const payload = "hello world\r\n"
333 const mult = 2
334
335 writer, reader := net.Pipe()
336 go func() {
337 if _, err := io.WriteString(writer, strings.Repeat(payload, mult)); err != nil {
338 t.Fatal(err)
339 }
340 if err := writer.Close(); err != nil {
341 t.Fatal(err)
342 }
343 }()
344
345 l := newChanListener()
346 defer close(l.connCh)
347 l.connCh <- reader
348 muxl := New(l)
349
350
351
352 muxl.Match(func(r io.Reader) bool {
353 var b [len(payload)]byte
354 _, _ = r.Read(b[:])
355 return false
356 })
357 anyl := muxl.Match(Any())
358 go safeServe(errCh, muxl)
359 muxedConn, err := anyl.Accept()
360 if err != nil {
361 t.Fatal(err)
362 }
363 for i := 0; i < mult; i++ {
364 var b [len(payload)]byte
365 n, err := muxedConn.Read(b[:])
366 if err != nil {
367 t.Error(err)
368 continue
369 }
370 if e := len(b); n != e {
371 t.Errorf("expected to read %d bytes, but read %d bytes", e, n)
372 }
373 }
374 var b [1]byte
375 if _, err := muxedConn.Read(b[:]); err != io.EOF {
376 t.Errorf("unexpected error %v, expected %v", err, io.EOF)
377 }
378 }
379
380 func TestAny(t *testing.T) {
381 defer leakCheck(t)()
382 errCh := make(chan error)
383 defer func() {
384 select {
385 case err := <-errCh:
386 t.Fatal(err)
387 default:
388 }
389 }()
390 l, cleanup := testListener(t)
391 defer cleanup()
392
393 muxl := New(l)
394 httpl := muxl.Match(Any())
395
396 go runTestHTTPServer(errCh, httpl)
397 go safeServe(errCh, muxl)
398
399 runTestHTTP1Client(t, l.Addr())
400 }
401
402 func TestTLS(t *testing.T) {
403 generateTLSCert(t)
404 defer cleanupTLSCert(t)
405 defer leakCheck(t)()
406 errCh := make(chan error)
407 defer func() {
408 select {
409 case err := <-errCh:
410 t.Fatal(err)
411 default:
412 }
413 }()
414 l, cleanup := testListener(t)
415 defer cleanup()
416
417 muxl := New(l)
418 tlsl := muxl.Match(TLS())
419 httpl := muxl.Match(Any())
420
421 go runTestTLSServer(errCh, tlsl)
422 go runTestHTTPServer(errCh, httpl)
423 go safeServe(errCh, muxl)
424
425 runTestHTTP1Client(t, l.Addr())
426 runTestTLSClient(t, l.Addr())
427 }
428
429 func TestHTTP2(t *testing.T) {
430 defer leakCheck(t)()
431 errCh := make(chan error)
432 defer func() {
433 select {
434 case err := <-errCh:
435 t.Fatal(err)
436 default:
437 }
438 }()
439 writer, reader := net.Pipe()
440 go func() {
441 if _, err := io.WriteString(writer, http2.ClientPreface); err != nil {
442 t.Fatal(err)
443 }
444 if err := writer.Close(); err != nil {
445 t.Fatal(err)
446 }
447 }()
448
449 l := newChanListener()
450 l.connCh <- reader
451 muxl := New(l)
452
453 muxl.Match(func(r io.Reader) bool {
454 var b [1]byte
455 _, _ = r.Read(b[:])
456 return false
457 })
458 h2l := muxl.Match(HTTP2())
459 go safeServe(errCh, muxl)
460 muxedConn, err := h2l.Accept()
461 close(l.connCh)
462 if err != nil {
463 t.Fatal(err)
464 }
465 var b [len(http2.ClientPreface)]byte
466 var n int
467
468 if n, err = muxedConn.Read(b[:]); err == io.EOF {
469 t.Fatal(err)
470 }
471
472 if _, err = muxedConn.Read(b[n:]); err != io.EOF {
473 t.Fatal(err)
474 }
475 if string(b[:]) != http2.ClientPreface {
476 t.Errorf("got unexpected read %s, expected %s", b, http2.ClientPreface)
477 }
478 }
479
480 func TestHTTP2MatchHeaderField(t *testing.T) {
481 testHTTP2MatchHeaderField(t, HTTP2HeaderField, "value", "value", "anothervalue")
482 }
483
484 func TestHTTP2MatchHeaderFieldPrefix(t *testing.T) {
485 testHTTP2MatchHeaderField(t, HTTP2HeaderFieldPrefix, "application/grpc+proto", "application/grpc", "application/json")
486 }
487
488 func testHTTP2MatchHeaderField(
489 t *testing.T,
490 matcherConstructor func(string, string) Matcher,
491 headerValue string,
492 matchValue string,
493 notMatchValue string,
494 ) {
495 defer leakCheck(t)()
496 errCh := make(chan error)
497 defer func() {
498 select {
499 case err := <-errCh:
500 t.Fatal(err)
501 default:
502 }
503 }()
504 name := "name"
505 writer, reader := net.Pipe()
506 go func() {
507 if _, err := io.WriteString(writer, http2.ClientPreface); err != nil {
508 t.Fatal(err)
509 }
510 var buf bytes.Buffer
511 enc := hpack.NewEncoder(&buf)
512 if err := enc.WriteField(hpack.HeaderField{Name: name, Value: headerValue}); err != nil {
513 t.Fatal(err)
514 }
515 framer := http2.NewFramer(writer, nil)
516 err := framer.WriteHeaders(http2.HeadersFrameParam{
517 StreamID: 1,
518 BlockFragment: buf.Bytes(),
519 EndStream: true,
520 EndHeaders: true,
521 })
522 if err != nil {
523 t.Fatal(err)
524 }
525 if err := writer.Close(); err != nil {
526 t.Fatal(err)
527 }
528 }()
529
530 l := newChanListener()
531 l.connCh <- reader
532 muxl := New(l)
533
534 muxl.Match(func(r io.Reader) bool {
535 var b [1]byte
536 _, _ = r.Read(b[:])
537 return false
538 })
539
540 muxl.Match(matcherConstructor(name, notMatchValue))
541
542 h2l := muxl.Match(matcherConstructor(name, matchValue))
543 go safeServe(errCh, muxl)
544 muxedConn, err := h2l.Accept()
545 close(l.connCh)
546 if err != nil {
547 t.Fatal(err)
548 }
549 var b [len(http2.ClientPreface)]byte
550
551 if _, err := muxedConn.Read(b[:]); err == io.EOF {
552 t.Fatal(err)
553 }
554 if string(b[:]) != http2.ClientPreface {
555 t.Errorf("got unexpected read %s, expected %s", b, http2.ClientPreface)
556 }
557 }
558
559 func TestHTTPGoRPC(t *testing.T) {
560 defer leakCheck(t)()
561 errCh := make(chan error)
562 defer func() {
563 select {
564 case err := <-errCh:
565 t.Fatal(err)
566 default:
567 }
568 }()
569 l, cleanup := testListener(t)
570 defer cleanup()
571
572 muxl := New(l)
573 httpl := muxl.Match(HTTP2(), HTTP1Fast())
574 rpcl := muxl.Match(Any())
575
576 go runTestHTTPServer(errCh, httpl)
577 go runTestRPCServer(errCh, rpcl)
578 go safeServe(errCh, muxl)
579
580 runTestHTTP1Client(t, l.Addr())
581 runTestRPCClient(t, l.Addr())
582 }
583
584 func TestErrorHandler(t *testing.T) {
585 defer leakCheck(t)()
586 errCh := make(chan error)
587 defer func() {
588 select {
589 case err := <-errCh:
590 t.Fatal(err)
591 default:
592 }
593 }()
594 l, cleanup := testListener(t)
595 defer cleanup()
596
597 muxl := New(l)
598 httpl := muxl.Match(HTTP2(), HTTP1Fast())
599
600 go runTestHTTPServer(errCh, httpl)
601 go safeServe(errCh, muxl)
602
603 var errCount uint32
604 muxl.HandleError(func(err error) bool {
605 if atomic.AddUint32(&errCount, 1) == 1 {
606 if _, ok := err.(ErrNotMatched); !ok {
607 t.Errorf("unexpected error: %v", err)
608 }
609 }
610 return true
611 })
612
613 c, cleanup := safeDial(t, l.Addr())
614 defer cleanup()
615
616 var num int
617 for atomic.LoadUint32(&errCount) == 0 {
618 if err := c.Call("TestRPCRcvr.Test", rpcVal, &num); err == nil {
619
620 t.Errorf("unexpected rpc success after %d errors", atomic.LoadUint32(&errCount))
621 }
622 }
623 }
624
625 func TestMultipleMatchers(t *testing.T) {
626 defer leakCheck(t)()
627 errCh := make(chan error)
628 defer func() {
629 select {
630 case err := <-errCh:
631 t.Fatal(err)
632 default:
633 }
634 }()
635 l, cleanup := testListener(t)
636 defer cleanup()
637
638 matcher := func(r io.Reader) bool {
639 return true
640 }
641 unmatcher := func(r io.Reader) bool {
642 return false
643 }
644
645 muxl := New(l)
646 lis := muxl.Match(unmatcher, matcher, unmatcher)
647
648 go runTestHTTPServer(errCh, lis)
649 go safeServe(errCh, muxl)
650
651 runTestHTTP1Client(t, l.Addr())
652 }
653
654 func TestListenerClose(t *testing.T) {
655 defer leakCheck(t)()
656 errCh := make(chan error)
657 defer func() {
658 select {
659 case err := <-errCh:
660 t.Fatal(err)
661 default:
662 }
663 }()
664 l := newChanListener()
665
666 c1, c2 := net.Pipe()
667
668 muxl := New(l)
669 anyl := muxl.Match(Any())
670
671 go safeServe(errCh, muxl)
672
673 l.connCh <- c1
674
675
676 if _, err := anyl.Accept(); err != nil {
677 t.Fatal(err)
678 }
679
680
681 l.connCh <- c2
682
683
684 close(l.connCh)
685
686
687 if _, err := anyl.Accept(); err != nil {
688 if err != ErrListenerClosed && err != ErrServerClosed {
689 t.Fatal(err)
690 }
691
692
693 if _, err := c2.Read([]byte{}); !strings.Contains(err.Error(), "closed") {
694 t.Fatalf("connection is not closed and is leaked: %v", err)
695 }
696 }
697 }
698
699 func TestClose(t *testing.T) {
700 defer leakCheck(t)()
701 errCh := make(chan error)
702 defer func() {
703 select {
704 case err := <-errCh:
705 t.Fatal(err)
706 default:
707 }
708 }()
709 l, cleanup := testListener(t)
710 defer cleanup()
711
712 muxl := New(l)
713 anyl := muxl.Match(Any())
714
715 go safeServe(errCh, muxl)
716
717 muxl.Close()
718
719 if _, err := anyl.Accept(); err != ErrServerClosed {
720 t.Fatal(err)
721 }
722 }
723
724
725
726
727
// interestingGoroutines captures all goroutine dumps with runtime.Stack
// (at most 2 MiB) and returns them sorted. It skips dumps with an empty
// stack, stacks starting with testing.RunTests, and stacks that mention any
// runtime/test-harness marker in the ignore list (including this function
// itself). Each returned entry is one goroutine's full dump: header line
// plus stack.
func interestingGoroutines() []string {
	raw := make([]byte, 2<<20)
	raw = raw[:runtime.Stack(raw, true)]

	ignored := []string{
		"main.main()",
		"testing.Main(",
		"runtime.goexit",
		"created by runtime.gc",
		"interestingGoroutines",
		"runtime.MHeap_Scavenger",
	}

	var kept []string
dumps:
	for _, dump := range strings.Split(string(raw), "\n\n") {
		nl := strings.Index(dump, "\n")
		if nl < 0 {
			continue
		}
		stack := strings.TrimSpace(dump[nl+1:])
		if stack == "" || strings.HasPrefix(stack, "testing.RunTests") {
			continue
		}
		for _, marker := range ignored {
			if strings.Contains(stack, marker) {
				continue dumps
			}
		}
		kept = append(kept, dump)
	}
	sort.Strings(kept)
	return kept
}
755
756
757
758
759 func leakCheck(t testing.TB) func() {
760 orig := map[string]bool{}
761 for _, g := range interestingGoroutines() {
762 orig[g] = true
763 }
764 return func() {
765
766
767 deadline := time.Now().Add(5 * time.Second)
768 for {
769 var leaked []string
770 for _, g := range interestingGoroutines() {
771 if !orig[g] {
772 leaked = append(leaked, g)
773 }
774 }
775 if len(leaked) == 0 {
776 return
777 }
778 if time.Now().Before(deadline) {
779 time.Sleep(50 * time.Millisecond)
780 continue
781 }
782 for _, g := range leaked {
783 t.Errorf("Leaked goroutine: %v", g)
784 }
785 return
786 }
787 }
788 }
789
View as plain text