1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package proxy
16
17 import (
18 "bytes"
19 "context"
20 "crypto/tls"
21 "fmt"
22 "io/ioutil"
23 "log"
24 "math/rand"
25 "net"
26 "net/http"
27 "net/url"
28 "os"
29 "strings"
30 "testing"
31 "time"
32
33 "github.com/stretchr/testify/assert"
34 "go.etcd.io/etcd/client/pkg/v3/transport"
35 "go.uber.org/zap/zaptest"
36
37 "go.uber.org/zap"
38 )
39
40 func TestServer_Unix_Insecure(t *testing.T) { testServer(t, "unix", false, false) }
41 func TestServer_TCP_Insecure(t *testing.T) { testServer(t, "tcp", false, false) }
42 func TestServer_Unix_Secure(t *testing.T) { testServer(t, "unix", true, false) }
43 func TestServer_TCP_Secure(t *testing.T) { testServer(t, "tcp", true, false) }
44 func TestServer_Unix_Insecure_DelayTx(t *testing.T) { testServer(t, "unix", false, true) }
45 func TestServer_TCP_Insecure_DelayTx(t *testing.T) { testServer(t, "tcp", false, true) }
46 func TestServer_Unix_Secure_DelayTx(t *testing.T) { testServer(t, "unix", true, true) }
47 func TestServer_TCP_Secure_DelayTx(t *testing.T) { testServer(t, "tcp", true, true) }
48
// testServer exercises a proxy between two addresses of the given scheme.
// It pushes two payloads through the proxy to a destination listener and
// checks that both arrive unmodified. When delayTx is true, the second payload
// is sent with DelayTx(50ms, 5ms) in effect and must take at least lat-rv.
// Finally it closes the proxy and checks that Done fires (or a tolerated
// error is reported) within three seconds.
//
// scheme is "unix" or "tcp" (see the TestServer_* wrappers). When secure is
// true, the destination listener and the proxy use the TLS fixtures from
// createTLSInfo, and send dials with the same TLS info.
func testServer(t *testing.T, scheme string, secure bool, delayTx bool) {
	lg := zaptest.NewLogger(t)
	srcAddr, dstAddr := newUnixAddr(), newUnixAddr()
	if scheme == "tcp" {
		// Grab two free ports by listening on port 0, then release them so
		// the proxy and the destination listener can bind to them.
		// NOTE(review): another process may take a port between Close and
		// the later bind; this pattern is inherently racy.
		ln1, ln2 := listen(t, "tcp", "localhost:0", transport.TLSInfo{}), listen(t, "tcp", "localhost:0", transport.TLSInfo{})
		srcAddr, dstAddr = ln1.Addr().String(), ln2.Addr().String()
		ln1.Close()
		ln2.Close()
	} else {
		// newUnixAddr returns relative paths; remove the socket files when done.
		defer func() {
			os.RemoveAll(srcAddr)
			os.RemoveAll(dstAddr)
		}()
	}
	tlsInfo := createTLSInfo(lg, secure)
	// Destination listener that the proxy forwards to.
	ln := listen(t, scheme, dstAddr, tlsInfo)
	defer ln.Close()

	cfg := ServerConfig{
		Logger: lg,
		From:   url.URL{Scheme: scheme, Host: srcAddr},
		To:     url.URL{Scheme: scheme, Host: dstAddr},
	}
	if secure {
		cfg.TLSInfo = tlsInfo
	}
	p := NewServer(cfg)
	<-p.Ready()
	// NOTE(review): p.Close is also called explicitly below; presumably Close
	// tolerates being called twice — confirm.
	defer p.Close()

	data1 := []byte("Hello World!")
	donec, writec := make(chan struct{}), make(chan []byte)

	// Sender: writes each payload received on writec to the proxy's From
	// address and closes donec once writec is closed and drained.
	// NOTE(review): send and receive call t.Fatal, but testing.T.FailNow must
	// only be called from the test goroutine, not from goroutines like these.
	go func() {
		defer close(donec)
		for data := range writec {
			send(t, data, scheme, srcAddr, tlsInfo)
		}
	}()

	// Receiver: collects exactly two payloads from the destination listener.
	recvc := make(chan []byte, 1)
	go func() {
		for i := 0; i < 2; i++ {
			recvc <- receive(t, ln)
		}
	}()

	// Baseline round trip with no injected latency.
	writec <- data1
	now := time.Now()
	if d := <-recvc; !bytes.Equal(data1, d) {
		t.Fatalf("expected %q, got %q", string(data1), string(d))
	}
	took1 := time.Since(now)
	t.Logf("took %v with no latency", took1)

	lat, rv := 50*time.Millisecond, 5*time.Millisecond
	if delayTx {
		p.DelayTx(lat, rv)
	}

	// Second round trip; delayed when delayTx is set.
	data2 := []byte("new data")
	writec <- data2
	now = time.Now()
	if d := <-recvc; !bytes.Equal(data2, d) {
		t.Fatalf("expected %q, got %q", string(data2), string(d))
	}
	took2 := time.Since(now)
	if delayTx {
		t.Logf("took %v with latency %v+-%v", took2, lat, rv)
	} else {
		t.Logf("took %v with no latency", took2)
	}

	if delayTx {
		p.UndelayTx()
		// The delayed transfer must take at least the minimum latency lat-rv.
		if took2 < lat-rv {
			t.Fatalf("expected took2 %v (with latency) > delay: %v", took2, lat-rv)
		}
	}

	// Stop the sender and wait for it to finish its last write.
	close(writec)
	select {
	case <-donec:
	case <-time.After(3 * time.Second):
		t.Fatal("took too long to write")
	}

	// Before Close, the proxy must be neither done nor reporting an error.
	select {
	case <-p.Done():
		t.Fatal("unexpected done")
	case err := <-p.Error():
		t.Fatal(err)
	default:
	}

	if err := p.Close(); err != nil {
		t.Fatal(err)
	}

	// After Close, accept either Done or an error whose text starts with
	// "accept " or ends with "use of closed network connection"; any other
	// error, or no signal within three seconds, fails the test.
	select {
	case <-p.Done():
	case err := <-p.Error():
		if !strings.HasPrefix(err.Error(), "accept ") &&
			!strings.HasSuffix(err.Error(), "use of closed network connection") {
			t.Fatal(err)
		}
	case <-time.After(3 * time.Second):
		t.Fatal("took too long to close")
	}
}
159
160 func createTLSInfo(lg *zap.Logger, secure bool) transport.TLSInfo {
161 if secure {
162 return transport.TLSInfo{
163 KeyFile: "../../tests/fixtures/server.key.insecure",
164 CertFile: "../../tests/fixtures/server.crt",
165 TrustedCAFile: "../../tests/fixtures/ca.crt",
166 ClientCertAuth: true,
167 Logger: lg,
168 }
169 }
170 return transport.TLSInfo{Logger: lg}
171 }
172
173 func TestServer_Unix_Insecure_DelayAccept(t *testing.T) { testServerDelayAccept(t, false) }
174 func TestServer_Unix_Secure_DelayAccept(t *testing.T) { testServerDelayAccept(t, true) }
// testServerDelayAccept measures one round trip through a unix-socket proxy.
// It then enables DelayAccept(700ms, 10ms), resets the proxy's listener, and
// checks that a second round trip takes longer than the first.
//
// secure selects the TLS fixtures from createTLSInfo for the destination
// listener, the proxy, and the sender.
func testServerDelayAccept(t *testing.T, secure bool) {
	lg := zaptest.NewLogger(t)
	srcAddr, dstAddr := newUnixAddr(), newUnixAddr()
	defer func() {
		os.RemoveAll(srcAddr)
		os.RemoveAll(dstAddr)
	}()
	tlsInfo := createTLSInfo(lg, secure)
	scheme := "unix"
	ln := listen(t, scheme, dstAddr, tlsInfo)
	defer ln.Close()

	cfg := ServerConfig{
		Logger: lg,
		From:   url.URL{Scheme: scheme, Host: srcAddr},
		To:     url.URL{Scheme: scheme, Host: dstAddr},
	}
	if secure {
		cfg.TLSInfo = tlsInfo
	}
	p := NewServer(cfg)
	<-p.Ready()
	defer p.Close()

	data := []byte("Hello World!")

	// Baseline round trip with no accept delay.
	now := time.Now()
	send(t, data, scheme, srcAddr, tlsInfo)
	if d := receive(t, ln); !bytes.Equal(data, d) {
		t.Fatalf("expected %q, got %q", string(data), string(d))
	}
	took1 := time.Since(now)
	t.Logf("took %v with no latency", took1)

	lat, rv := 700*time.Millisecond, 10*time.Millisecond
	p.DelayAccept(lat, rv)
	defer p.UndelayAccept()
	// NOTE(review): presumably ResetListener recreates the proxy's listener so
	// the accept delay applies to new connections, and the 200ms sleep lets
	// it come back up — confirm against the server implementation.
	if err := p.ResetListener(); err != nil {
		t.Fatal(err)
	}
	time.Sleep(200 * time.Millisecond)

	// Second round trip with the accept delay in effect.
	now = time.Now()
	send(t, data, scheme, srcAddr, tlsInfo)
	if d := receive(t, ln); !bytes.Equal(data, d) {
		t.Fatalf("expected %q, got %q", string(data), string(d))
	}
	took2 := time.Since(now)
	t.Logf("took %v with latency %v±%v", took2, lat, rv)

	// Only the ordering is asserted, not that took2 reaches lat-rv.
	if took1 >= took2 {
		t.Fatalf("expected took1 %v < took2 %v", took1, took2)
	}
}
229
// TestServer_PauseTx checks that PauseTx holds back data sent through the
// proxy, so nothing reaches the destination within 200ms. It then checks that
// after UnpauseTx the original payload arrives within two seconds.
func TestServer_PauseTx(t *testing.T) {
	lg := zaptest.NewLogger(t)
	scheme := "unix"
	srcAddr, dstAddr := newUnixAddr(), newUnixAddr()
	defer func() {
		os.RemoveAll(srcAddr)
		os.RemoveAll(dstAddr)
	}()
	ln := listen(t, scheme, dstAddr, transport.TLSInfo{})
	defer ln.Close()

	p := NewServer(ServerConfig{
		Logger: lg,
		From:   url.URL{Scheme: scheme, Host: srcAddr},
		To:     url.URL{Scheme: scheme, Host: dstAddr},
	})
	<-p.Ready()
	defer p.Close()

	p.PauseTx()

	// send writes the payload to the proxy and closes its connection; it is
	// delivery to ln that the pause is expected to hold back.
	data := []byte("Hello World!")
	send(t, data, scheme, srcAddr, transport.TLSInfo{})

	// Receive in the background so the test can time out while paused.
	// NOTE(review): receive may call t.Fatal from this goroutine, which the
	// testing package does not support (FailNow must run on the test goroutine).
	recvc := make(chan []byte, 1)
	go func() {
		recvc <- receive(t, ln)
	}()

	select {
	case d := <-recvc:
		t.Fatalf("received unexpected data %q during pause", string(d))
	case <-time.After(200 * time.Millisecond):
	}

	p.UnpauseTx()

	// The held-back payload must now arrive intact.
	select {
	case d := <-recvc:
		if !bytes.Equal(data, d) {
			t.Fatalf("expected %q, got %q", string(data), string(d))
		}
	case <-time.After(2 * time.Second):
		t.Fatal("took too long to receive after unpause")
	}
}
276
277 func TestServer_ModifyTx_corrupt(t *testing.T) {
278 lg := zaptest.NewLogger(t)
279 scheme := "unix"
280 srcAddr, dstAddr := newUnixAddr(), newUnixAddr()
281 defer func() {
282 os.RemoveAll(srcAddr)
283 os.RemoveAll(dstAddr)
284 }()
285 ln := listen(t, scheme, dstAddr, transport.TLSInfo{})
286 defer ln.Close()
287
288 p := NewServer(ServerConfig{
289 Logger: lg,
290 From: url.URL{Scheme: scheme, Host: srcAddr},
291 To: url.URL{Scheme: scheme, Host: dstAddr},
292 })
293 <-p.Ready()
294 defer p.Close()
295
296 p.ModifyTx(func(d []byte) []byte {
297 d[len(d)/2]++
298 return d
299 })
300 data := []byte("Hello World!")
301 send(t, data, scheme, srcAddr, transport.TLSInfo{})
302 if d := receive(t, ln); bytes.Equal(d, data) {
303 t.Fatalf("expected corrupted data, got %q", string(d))
304 }
305
306 p.UnmodifyTx()
307 send(t, data, scheme, srcAddr, transport.TLSInfo{})
308 if d := receive(t, ln); !bytes.Equal(d, data) {
309 t.Fatalf("expected uncorrupted data, got %q", string(d))
310 }
311 }
312
313 func TestServer_ModifyTx_packet_loss(t *testing.T) {
314 lg := zaptest.NewLogger(t)
315 scheme := "unix"
316 srcAddr, dstAddr := newUnixAddr(), newUnixAddr()
317 defer func() {
318 os.RemoveAll(srcAddr)
319 os.RemoveAll(dstAddr)
320 }()
321 ln := listen(t, scheme, dstAddr, transport.TLSInfo{})
322 defer ln.Close()
323
324 p := NewServer(ServerConfig{
325 Logger: lg,
326 From: url.URL{Scheme: scheme, Host: srcAddr},
327 To: url.URL{Scheme: scheme, Host: dstAddr},
328 })
329 <-p.Ready()
330 defer p.Close()
331
332
333 p.ModifyTx(func(d []byte) []byte {
334 half := len(d) / 2
335 return d[:half:half]
336 })
337 data := []byte("Hello World!")
338 send(t, data, scheme, srcAddr, transport.TLSInfo{})
339 if d := receive(t, ln); bytes.Equal(d, data) {
340 t.Fatalf("expected corrupted data, got %q", string(d))
341 }
342
343 p.UnmodifyTx()
344 send(t, data, scheme, srcAddr, transport.TLSInfo{})
345 if d := receive(t, ln); !bytes.Equal(d, data) {
346 t.Fatalf("expected uncorrupted data, got %q", string(d))
347 }
348 }
349
// TestServer_BlackholeTx checks that data sent while BlackholeTx is active
// does not reach the destination within 200ms. It then checks that after
// UnblackholeTx a new, distinct payload is delivered within two seconds.
func TestServer_BlackholeTx(t *testing.T) {
	lg := zaptest.NewLogger(t)
	scheme := "unix"
	srcAddr, dstAddr := newUnixAddr(), newUnixAddr()
	defer func() {
		os.RemoveAll(srcAddr)
		os.RemoveAll(dstAddr)
	}()
	ln := listen(t, scheme, dstAddr, transport.TLSInfo{})
	defer ln.Close()

	p := NewServer(ServerConfig{
		Logger: lg,
		From:   url.URL{Scheme: scheme, Host: srcAddr},
		To:     url.URL{Scheme: scheme, Host: dstAddr},
	})
	<-p.Ready()
	defer p.Close()

	p.BlackholeTx()

	data := []byte("Hello World!")
	send(t, data, scheme, srcAddr, transport.TLSInfo{})

	// A single background receive. It skips connections that deliver zero
	// bytes, so it should only return once non-empty data arrives.
	// NOTE(review): presumably the blackholed connection delivers nothing and
	// is skipped by receive — confirm. receive may also call t.Fatal from
	// this goroutine, which FailNow does not support.
	recvc := make(chan []byte, 1)
	go func() {
		recvc <- receive(t, ln)
	}()

	select {
	case d := <-recvc:
		t.Fatalf("unexpected data receive %q during blackhole", string(d))
	case <-time.After(200 * time.Millisecond):
	}

	p.UnblackholeTx()

	// Change the payload so that what arrives is provably the data sent after
	// UnblackholeTx, not the dropped first payload. Mutating data is safe:
	// send has already written and closed the first connection.
	data[0]++
	send(t, data, scheme, srcAddr, transport.TLSInfo{})

	select {
	case d := <-recvc:
		if !bytes.Equal(data, d) {
			t.Fatalf("expected %q, got %q", string(data), string(d))
		}
	case <-time.After(2 * time.Second):
		t.Fatal("took too long to receive after unblackhole")
	}
}
400
401 func TestServer_Shutdown(t *testing.T) {
402 lg := zaptest.NewLogger(t)
403 scheme := "unix"
404 srcAddr, dstAddr := newUnixAddr(), newUnixAddr()
405 defer func() {
406 os.RemoveAll(srcAddr)
407 os.RemoveAll(dstAddr)
408 }()
409 ln := listen(t, scheme, dstAddr, transport.TLSInfo{})
410 defer ln.Close()
411
412 p := NewServer(ServerConfig{
413 Logger: lg,
414 From: url.URL{Scheme: scheme, Host: srcAddr},
415 To: url.URL{Scheme: scheme, Host: dstAddr},
416 })
417 <-p.Ready()
418 defer p.Close()
419
420 s, _ := p.(*server)
421 s.listener.Close()
422 time.Sleep(200 * time.Millisecond)
423
424 data := []byte("Hello World!")
425 send(t, data, scheme, srcAddr, transport.TLSInfo{})
426 if d := receive(t, ln); !bytes.Equal(d, data) {
427 t.Fatalf("expected %q, got %q", string(data), string(d))
428 }
429 }
430
// TestServer_ShutdownListener closes the destination listener while the proxy
// is running and re-listens on the same address after 200ms. It then checks
// that data sent through the proxy reaches the new listener.
func TestServer_ShutdownListener(t *testing.T) {
	lg := zaptest.NewLogger(t)
	scheme := "unix"
	srcAddr, dstAddr := newUnixAddr(), newUnixAddr()
	defer func() {
		os.RemoveAll(srcAddr)
		os.RemoveAll(dstAddr)
	}()

	ln := listen(t, scheme, dstAddr, transport.TLSInfo{})
	defer ln.Close()

	p := NewServer(ServerConfig{
		Logger: lg,
		From:   url.URL{Scheme: scheme, Host: srcAddr},
		To:     url.URL{Scheme: scheme, Host: dstAddr},
	})
	<-p.Ready()
	defer p.Close()

	// Take the destination down underneath the running proxy.
	ln.Close()
	time.Sleep(200 * time.Millisecond)

	// Bring a new destination up on the same address. The earlier deferred
	// ln.Close bound the old listener value when it was deferred, so this
	// second defer is the one that closes the new listener.
	ln = listen(t, scheme, dstAddr, transport.TLSInfo{})
	defer ln.Close()

	data := []byte("Hello World!")
	send(t, data, scheme, srcAddr, transport.TLSInfo{})
	if d := receive(t, ln); !bytes.Equal(d, data) {
		t.Fatalf("expected %q, got %q", string(data), string(d))
	}
}
464
465 func TestServerHTTP_Insecure_DelayTx(t *testing.T) { testServerHTTP(t, false, true) }
466 func TestServerHTTP_Secure_DelayTx(t *testing.T) { testServerHTTP(t, true, true) }
467 func TestServerHTTP_Insecure_DelayRx(t *testing.T) { testServerHTTP(t, false, false) }
468 func TestServerHTTP_Secure_DelayRx(t *testing.T) { testServerHTTP(t, true, false) }
469 func testServerHTTP(t *testing.T, secure, delayTx bool) {
470 lg := zaptest.NewLogger(t)
471 scheme := "tcp"
472 ln1, ln2 := listen(t, scheme, "localhost:0", transport.TLSInfo{}), listen(t, scheme, "localhost:0", transport.TLSInfo{})
473 srcAddr, dstAddr := ln1.Addr().String(), ln2.Addr().String()
474 ln1.Close()
475 ln2.Close()
476
477 mux := http.NewServeMux()
478 mux.HandleFunc("/hello", func(w http.ResponseWriter, req *http.Request) {
479 d, err := ioutil.ReadAll(req.Body)
480 req.Body.Close()
481 if err != nil {
482 t.Fatal(err)
483 }
484 if _, err = w.Write([]byte(fmt.Sprintf("%q(confirmed)", string(d)))); err != nil {
485 t.Fatal(err)
486 }
487 })
488 tlsInfo := createTLSInfo(lg, secure)
489 var tlsConfig *tls.Config
490 if secure {
491 _, err := tlsInfo.ServerConfig()
492 if err != nil {
493 t.Fatal(err)
494 }
495 }
496 srv := &http.Server{
497 Addr: dstAddr,
498 Handler: mux,
499 TLSConfig: tlsConfig,
500 ErrorLog: log.New(ioutil.Discard, "net/http", 0),
501 }
502
503 donec := make(chan struct{})
504 defer func() {
505 srv.Close()
506 <-donec
507 }()
508 go func() {
509 if !secure {
510 srv.ListenAndServe()
511 } else {
512 srv.ListenAndServeTLS(tlsInfo.CertFile, tlsInfo.KeyFile)
513 }
514 defer close(donec)
515 }()
516 time.Sleep(200 * time.Millisecond)
517
518 cfg := ServerConfig{
519 Logger: lg,
520 From: url.URL{Scheme: scheme, Host: srcAddr},
521 To: url.URL{Scheme: scheme, Host: dstAddr},
522 }
523 if secure {
524 cfg.TLSInfo = tlsInfo
525 }
526 p := NewServer(cfg)
527 <-p.Ready()
528 defer func() {
529 lg.Info("closing Proxy server...")
530 p.Close()
531 lg.Info("closed Proxy server.")
532 }()
533
534 data := "Hello World!"
535
536 var resp *http.Response
537 var err error
538 now := time.Now()
539 if secure {
540 tp, terr := transport.NewTransport(tlsInfo, 3*time.Second)
541 assert.NoError(t, terr)
542 cli := &http.Client{Transport: tp}
543 resp, err = cli.Post("https://"+srcAddr+"/hello", "", strings.NewReader(data))
544 defer cli.CloseIdleConnections()
545 defer tp.CloseIdleConnections()
546 } else {
547 resp, err = http.Post("http://"+srcAddr+"/hello", "", strings.NewReader(data))
548 defer http.DefaultClient.CloseIdleConnections()
549 }
550 assert.NoError(t, err)
551 d, err := ioutil.ReadAll(resp.Body)
552 if err != nil {
553 t.Fatal(err)
554 }
555 resp.Body.Close()
556 took1 := time.Since(now)
557 t.Logf("took %v with no latency", took1)
558
559 rs1 := string(d)
560 exp := fmt.Sprintf("%q(confirmed)", data)
561 if rs1 != exp {
562 t.Fatalf("got %q, expected %q", rs1, exp)
563 }
564
565 lat, rv := 100*time.Millisecond, 10*time.Millisecond
566 if delayTx {
567 p.DelayTx(lat, rv)
568 defer p.UndelayTx()
569 } else {
570 p.DelayRx(lat, rv)
571 defer p.UndelayRx()
572 }
573
574 now = time.Now()
575 if secure {
576 tp, terr := transport.NewTransport(tlsInfo, 3*time.Second)
577 if terr != nil {
578 t.Fatal(terr)
579 }
580 cli := &http.Client{Transport: tp}
581 resp, err = cli.Post("https://"+srcAddr+"/hello", "", strings.NewReader(data))
582 defer cli.CloseIdleConnections()
583 defer tp.CloseIdleConnections()
584 } else {
585 resp, err = http.Post("http://"+srcAddr+"/hello", "", strings.NewReader(data))
586 defer http.DefaultClient.CloseIdleConnections()
587 }
588 if err != nil {
589 t.Fatal(err)
590 }
591 d, err = ioutil.ReadAll(resp.Body)
592 if err != nil {
593 t.Fatal(err)
594 }
595 resp.Body.Close()
596 took2 := time.Since(now)
597 t.Logf("took %v with latency %v±%v", took2, lat, rv)
598
599 rs2 := string(d)
600 if rs2 != exp {
601 t.Fatalf("got %q, expected %q", rs2, exp)
602 }
603 if took1 > took2 {
604 t.Fatalf("expected took1 %v < took2 %v", took1, took2)
605 }
606 }
607
// newUnixAddr returns a fresh unix-socket path, relative to the working
// directory, of the form "<UnixNano hex><random hex>.unix-conn". Any leftover
// file at that path is removed first.
//
// The global math/rand source is deliberately not reseeded here. Reseeding
// with the current time on every call made the random suffix a pure function
// of the timestamp. Two back-to-back calls that read the same clock value
// (possible with coarse timer resolution) would then return the same path.
// rand.Seed is also deprecated as of Go 1.20.
func newUnixAddr() string {
	now := time.Now().UnixNano()
	addr := fmt.Sprintf("%X%X.unix-conn", now, rand.Intn(35000))
	os.RemoveAll(addr)
	return addr
}
615
616 func listen(t *testing.T, scheme, addr string, tlsInfo transport.TLSInfo) (ln net.Listener) {
617 var err error
618 if !tlsInfo.Empty() {
619 ln, err = transport.NewListener(addr, scheme, &tlsInfo)
620 } else {
621 ln, err = net.Listen(scheme, addr)
622 }
623 if err != nil {
624 t.Fatal(err)
625 }
626 return ln
627 }
628
629 func send(t *testing.T, data []byte, scheme, addr string, tlsInfo transport.TLSInfo) {
630 var out net.Conn
631 var err error
632 if !tlsInfo.Empty() {
633 tp, terr := transport.NewTransport(tlsInfo, 3*time.Second)
634 if terr != nil {
635 t.Fatal(terr)
636 }
637 out, err = tp.DialContext(context.Background(), scheme, addr)
638 } else {
639 out, err = net.Dial(scheme, addr)
640 }
641 if err != nil {
642 t.Fatal(err)
643 }
644 if _, err = out.Write(data); err != nil {
645 t.Fatal(err)
646 }
647 if err = out.Close(); err != nil {
648 t.Fatal(err)
649 }
650 }
651
// receive accepts connections on ln and reads each one to EOF, until one of
// them yields at least one byte. It returns that payload. Connections that
// deliver no data are skipped. Every accepted connection is closed once
// drained, so repeated calls do not leak file descriptors. Accept and read
// errors fail the test via t.Fatal, so receive must run on the test goroutine.
func receive(t *testing.T, ln net.Listener) (data []byte) {
	buf := bytes.NewBuffer(make([]byte, 0, 1024))
	for {
		in, err := ln.Accept()
		if err != nil {
			t.Fatal(err)
		}
		n, err := buf.ReadFrom(in)
		// The peer has finished sending (EOF) or failed; either way this
		// side of the connection is no longer needed.
		in.Close()
		if err != nil {
			t.Fatal(err)
		}
		if n > 0 {
			return buf.Bytes()
		}
	}
}
670
View as plain text