1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package proxy
16
17 import (
18 "context"
19 "fmt"
20 "io"
21 mrand "math/rand"
22 "net"
23 "net/http"
24 "net/url"
25 "strconv"
26 "strings"
27 "sync"
28 "time"
29
30 "go.etcd.io/etcd/client/pkg/v3/transport"
31
32 humanize "github.com/dustin/go-humanize"
33 "go.uber.org/zap"
34 )
35
// Fallback values that NewServer substitutes when the matching ServerConfig
// field is left at its zero value.
var (
	// defaultDialTimeout replaces a zero ServerConfig.DialTimeout.
	defaultDialTimeout = 3 * time.Second

	// defaultBufferSize replaces a zero ServerConfig.BufferSize. It is the
	// size, in bytes, of the buffer each copy goroutine reads into (48 KiB).
	defaultBufferSize = 48 << 10

	// defaultRetryInterval replaces a zero ServerConfig.RetryInterval. It is
	// how long the accept loop waits before re-creating a closed listener.
	defaultRetryInterval = 10 * time.Millisecond
)
41
42
43
44
45
// Server is a TCP (optionally TLS) proxy that forwards every connection
// accepted on one address to another, and can inject faults into that
// traffic. "Tx" methods act on data flowing From -> To (client to upstream);
// "Rx" methods act on data flowing To -> From (upstream to client).
type Server interface {
	// From returns the listen address as "scheme://host".
	From() string
	// To returns the forwarding target as "scheme://host".
	To() string

	// Ready returns a channel that is closed once the accept loop runs.
	Ready() <-chan struct{}
	// Done returns a channel that is closed once Close has been called.
	Done() <-chan struct{}
	// Error returns the channel on which listener, dial and copy errors are
	// reported. A sender blocks while the channel is full, until it is
	// drained or the server is closed, so callers should keep reading it.
	Error() <-chan error
	// Close stops the proxy and waits for its goroutines to return.
	Close() error

	// PauseAccept holds the accept loop before its next Accept call.
	PauseAccept()
	// UnpauseAccept releases an accept loop held by PauseAccept.
	UnpauseAccept()

	// DelayAccept makes the accept loop wait roughly latency (randomized by
	// up to rv) before each Accept call. A non-positive latency is ignored.
	DelayAccept(latency, rv time.Duration)
	// UndelayAccept removes the accept delay.
	UndelayAccept()
	// LatencyAccept returns the accept delay currently in effect.
	LatencyAccept() time.Duration

	// DelayTx delays every From -> To chunk by roughly latency (randomized
	// by up to rv). A non-positive latency is ignored.
	DelayTx(latency, rv time.Duration)
	// UndelayTx removes the From -> To delay.
	UndelayTx()
	// LatencyTx returns the From -> To delay currently in effect.
	LatencyTx() time.Duration

	// DelayRx delays every To -> From chunk by roughly latency (randomized
	// by up to rv). A non-positive latency is ignored.
	DelayRx(latency, rv time.Duration)
	// UndelayRx removes the To -> From delay.
	UndelayRx()
	// LatencyRx returns the To -> From delay currently in effect.
	LatencyRx() time.Duration

	// ModifyTx passes every From -> To chunk through f and forwards the
	// result; an empty result drops the chunk.
	ModifyTx(f func(data []byte) []byte)
	// UnmodifyTx removes the From -> To rewrite hook.
	UnmodifyTx()

	// ModifyRx passes every To -> From chunk through f and forwards the
	// result; an empty result drops the chunk.
	ModifyRx(f func(data []byte) []byte)
	// UnmodifyRx removes the To -> From rewrite hook.
	UnmodifyRx()

	// BlackholeTx drops all From -> To data. It is built on ModifyTx, so it
	// replaces any hook installed there.
	BlackholeTx()
	// UnblackholeTx stops dropping From -> To data by removing the Tx hook.
	UnblackholeTx()

	// BlackholeRx drops all To -> From data. It is built on ModifyRx, so it
	// replaces any hook installed there.
	BlackholeRx()
	// UnblackholeRx stops dropping To -> From data by removing the Rx hook.
	UnblackholeRx()

	// PauseTx holds From -> To data until UnpauseTx is called.
	PauseTx()
	// UnpauseTx releases From -> To data held by PauseTx.
	UnpauseTx()

	// PauseRx holds To -> From data until UnpauseRx is called.
	PauseRx()
	// UnpauseRx releases To -> From data held by PauseRx.
	UnpauseRx()

	// ResetListener closes the current listener and binds a new one.
	ResetListener() error
}
130
131
132 type ServerConfig struct {
133 Logger *zap.Logger
134 From url.URL
135 To url.URL
136 TLSInfo transport.TLSInfo
137 DialTimeout time.Duration
138 BufferSize int
139 RetryInterval time.Duration
140 }
141
142 type server struct {
143 lg *zap.Logger
144
145 from url.URL
146 fromPort int
147 to url.URL
148 toPort int
149
150 tlsInfo transport.TLSInfo
151 dialTimeout time.Duration
152
153 bufferSize int
154 retryInterval time.Duration
155
156 readyc chan struct{}
157 donec chan struct{}
158 errc chan error
159
160 closeOnce sync.Once
161 closeWg sync.WaitGroup
162
163 listenerMu sync.RWMutex
164 listener net.Listener
165
166 pauseAcceptMu sync.Mutex
167 pauseAcceptc chan struct{}
168
169 latencyAcceptMu sync.RWMutex
170 latencyAccept time.Duration
171
172 modifyTxMu sync.RWMutex
173 modifyTx func(data []byte) []byte
174
175 modifyRxMu sync.RWMutex
176 modifyRx func(data []byte) []byte
177
178 pauseTxMu sync.Mutex
179 pauseTxc chan struct{}
180
181 pauseRxMu sync.Mutex
182 pauseRxc chan struct{}
183
184 latencyTxMu sync.RWMutex
185 latencyTx time.Duration
186
187 latencyRxMu sync.RWMutex
188 latencyRx time.Duration
189 }
190
191
192
193 func NewServer(cfg ServerConfig) Server {
194 s := &server{
195 lg: cfg.Logger,
196
197 from: cfg.From,
198 to: cfg.To,
199
200 tlsInfo: cfg.TLSInfo,
201 dialTimeout: cfg.DialTimeout,
202
203 bufferSize: cfg.BufferSize,
204 retryInterval: cfg.RetryInterval,
205
206 readyc: make(chan struct{}),
207 donec: make(chan struct{}),
208 errc: make(chan error, 16),
209
210 pauseAcceptc: make(chan struct{}),
211 pauseTxc: make(chan struct{}),
212 pauseRxc: make(chan struct{}),
213 }
214
215 _, fromPort, err := net.SplitHostPort(cfg.From.Host)
216 if err == nil {
217 s.fromPort, _ = strconv.Atoi(fromPort)
218 }
219 var toPort string
220 _, toPort, err = net.SplitHostPort(cfg.To.Host)
221 if err == nil {
222 s.toPort, _ = strconv.Atoi(toPort)
223 }
224
225 if s.dialTimeout == 0 {
226 s.dialTimeout = defaultDialTimeout
227 }
228 if s.bufferSize == 0 {
229 s.bufferSize = defaultBufferSize
230 }
231 if s.retryInterval == 0 {
232 s.retryInterval = defaultRetryInterval
233 }
234
235 close(s.pauseAcceptc)
236 close(s.pauseTxc)
237 close(s.pauseRxc)
238
239 if strings.HasPrefix(s.from.Scheme, "http") {
240 s.from.Scheme = "tcp"
241 }
242 if strings.HasPrefix(s.to.Scheme, "http") {
243 s.to.Scheme = "tcp"
244 }
245
246 addr := fmt.Sprintf(":%d", s.fromPort)
247 if s.fromPort == 0 {
248 addr = s.from.Host
249 }
250
251 var ln net.Listener
252 if !s.tlsInfo.Empty() {
253 ln, err = transport.NewListener(addr, s.from.Scheme, &s.tlsInfo)
254 } else {
255 ln, err = net.Listen(s.from.Scheme, addr)
256 }
257 if err != nil {
258 s.errc <- err
259 s.Close()
260 return s
261 }
262 s.listener = ln
263
264 s.closeWg.Add(1)
265 go s.listenAndServe()
266
267 s.lg.Info("started proxying", zap.String("from", s.From()), zap.String("to", s.To()))
268 return s
269 }
270
271 func (s *server) From() string {
272 return fmt.Sprintf("%s://%s", s.from.Scheme, s.from.Host)
273 }
274
275 func (s *server) To() string {
276 return fmt.Sprintf("%s://%s", s.to.Scheme, s.to.Host)
277 }
278
279
280
281
282
283
284 func (s *server) listenAndServe() {
285 defer s.closeWg.Done()
286
287 ctx := context.Background()
288 s.lg.Info("proxy is listening on", zap.String("from", s.From()))
289 close(s.readyc)
290
291 for {
292 s.pauseAcceptMu.Lock()
293 pausec := s.pauseAcceptc
294 s.pauseAcceptMu.Unlock()
295 select {
296 case <-pausec:
297 case <-s.donec:
298 return
299 }
300
301 s.latencyAcceptMu.RLock()
302 lat := s.latencyAccept
303 s.latencyAcceptMu.RUnlock()
304 if lat > 0 {
305 select {
306 case <-time.After(lat):
307 case <-s.donec:
308 return
309 }
310 }
311
312 s.listenerMu.RLock()
313 ln := s.listener
314 s.listenerMu.RUnlock()
315
316 in, err := ln.Accept()
317 if err != nil {
318 select {
319 case s.errc <- err:
320 select {
321 case <-s.donec:
322 return
323 default:
324 }
325 case <-s.donec:
326 return
327 }
328 s.lg.Debug("listener accept error", zap.Error(err))
329
330 if strings.HasSuffix(err.Error(), "use of closed network connection") {
331 select {
332 case <-time.After(s.retryInterval):
333 case <-s.donec:
334 return
335 }
336 s.lg.Debug("listener is closed; retry listening on", zap.String("from", s.From()))
337
338 if err = s.ResetListener(); err != nil {
339 select {
340 case s.errc <- err:
341 select {
342 case <-s.donec:
343 return
344 default:
345 }
346 case <-s.donec:
347 return
348 }
349 s.lg.Warn("failed to reset listener", zap.Error(err))
350 }
351 }
352
353 continue
354 }
355
356 var out net.Conn
357 if !s.tlsInfo.Empty() {
358 var tp *http.Transport
359 tp, err = transport.NewTransport(s.tlsInfo, s.dialTimeout)
360 if err != nil {
361 select {
362 case s.errc <- err:
363 select {
364 case <-s.donec:
365 return
366 default:
367 }
368 case <-s.donec:
369 return
370 }
371 continue
372 }
373 out, err = tp.DialContext(ctx, s.to.Scheme, s.to.Host)
374 } else {
375 out, err = net.Dial(s.to.Scheme, s.to.Host)
376 }
377 if err != nil {
378 select {
379 case s.errc <- err:
380 select {
381 case <-s.donec:
382 return
383 default:
384 }
385 case <-s.donec:
386 return
387 }
388 s.lg.Debug("failed to dial", zap.Error(err))
389 continue
390 }
391
392 s.closeWg.Add(2)
393 go func() {
394 defer s.closeWg.Done()
395
396 s.transmit(out, in)
397 out.Close()
398 in.Close()
399 }()
400 go func() {
401 defer s.closeWg.Done()
402
403 s.receive(in, out)
404 in.Close()
405 out.Close()
406 }()
407 }
408 }
409
410 func (s *server) transmit(dst io.Writer, src io.Reader) {
411 s.ioCopy(dst, src, proxyTx)
412 }
413
414 func (s *server) receive(dst io.Writer, src io.Reader) {
415 s.ioCopy(dst, src, proxyRx)
416 }
417
// proxyType selects which direction's fault-injection settings ioCopy uses.
type proxyType uint8

const (
	// proxyTx is client -> upstream (From -> To).
	proxyTx proxyType = 0
	// proxyRx is upstream -> client (To -> From).
	proxyRx proxyType = 1
)
424
425 func (s *server) ioCopy(dst io.Writer, src io.Reader, ptype proxyType) {
426 buf := make([]byte, s.bufferSize)
427 for {
428 nr1, err := src.Read(buf)
429 if err != nil {
430 if err == io.EOF {
431 return
432 }
433
434 if strings.HasSuffix(err.Error(), "read: connection reset by peer") {
435 return
436 }
437 if strings.HasSuffix(err.Error(), "use of closed network connection") {
438 return
439 }
440 select {
441 case s.errc <- err:
442 select {
443 case <-s.donec:
444 return
445 default:
446 }
447 case <-s.donec:
448 return
449 }
450 s.lg.Debug("failed to read", zap.Error(err))
451 return
452 }
453 if nr1 == 0 {
454 return
455 }
456 data := buf[:nr1]
457
458
459 switch ptype {
460 case proxyTx:
461 s.modifyTxMu.RLock()
462 if s.modifyTx != nil {
463 data = s.modifyTx(data)
464 }
465 s.modifyTxMu.RUnlock()
466 case proxyRx:
467 s.modifyRxMu.RLock()
468 if s.modifyRx != nil {
469 data = s.modifyRx(data)
470 }
471 s.modifyRxMu.RUnlock()
472 default:
473 panic("unknown proxy type")
474 }
475 nr2 := len(data)
476 switch ptype {
477 case proxyTx:
478 s.lg.Debug(
479 "modified tx",
480 zap.String("data-received", humanize.Bytes(uint64(nr1))),
481 zap.String("data-modified", humanize.Bytes(uint64(nr2))),
482 zap.String("from", s.From()),
483 zap.String("to", s.To()),
484 )
485 case proxyRx:
486 s.lg.Debug(
487 "modified rx",
488 zap.String("data-received", humanize.Bytes(uint64(nr1))),
489 zap.String("data-modified", humanize.Bytes(uint64(nr2))),
490 zap.String("from", s.To()),
491 zap.String("to", s.From()),
492 )
493 default:
494 panic("unknown proxy type")
495 }
496
497
498 var pausec chan struct{}
499 switch ptype {
500 case proxyTx:
501 s.pauseTxMu.Lock()
502 pausec = s.pauseTxc
503 s.pauseTxMu.Unlock()
504 case proxyRx:
505 s.pauseRxMu.Lock()
506 pausec = s.pauseRxc
507 s.pauseRxMu.Unlock()
508 default:
509 panic("unknown proxy type")
510 }
511 select {
512 case <-pausec:
513 case <-s.donec:
514 return
515 }
516
517
518 if nr2 == 0 {
519 continue
520 }
521
522
523 var lat time.Duration
524 switch ptype {
525 case proxyTx:
526 s.latencyTxMu.RLock()
527 lat = s.latencyTx
528 s.latencyTxMu.RUnlock()
529 case proxyRx:
530 s.latencyRxMu.RLock()
531 lat = s.latencyRx
532 s.latencyRxMu.RUnlock()
533 default:
534 panic("unknown proxy type")
535 }
536 if lat > 0 {
537 select {
538 case <-time.After(lat):
539 case <-s.donec:
540 return
541 }
542 }
543
544
545 var nw int
546 nw, err = dst.Write(data)
547 if err != nil {
548 if err == io.EOF {
549 return
550 }
551 select {
552 case s.errc <- err:
553 select {
554 case <-s.donec:
555 return
556 default:
557 }
558 case <-s.donec:
559 return
560 }
561 switch ptype {
562 case proxyTx:
563 s.lg.Debug("write fail on tx", zap.Error(err))
564 case proxyRx:
565 s.lg.Debug("write fail on rx", zap.Error(err))
566 default:
567 panic("unknown proxy type")
568 }
569 return
570 }
571
572 if nr2 != nw {
573 select {
574 case s.errc <- io.ErrShortWrite:
575 select {
576 case <-s.donec:
577 return
578 default:
579 }
580 case <-s.donec:
581 return
582 }
583 switch ptype {
584 case proxyTx:
585 s.lg.Debug(
586 "write fail on tx; read/write bytes are different",
587 zap.Int("read-bytes", nr1),
588 zap.Int("write-bytes", nw),
589 zap.Error(io.ErrShortWrite),
590 )
591 case proxyRx:
592 s.lg.Debug(
593 "write fail on rx; read/write bytes are different",
594 zap.Int("read-bytes", nr1),
595 zap.Int("write-bytes", nw),
596 zap.Error(io.ErrShortWrite),
597 )
598 default:
599 panic("unknown proxy type")
600 }
601 return
602 }
603
604 switch ptype {
605 case proxyTx:
606 s.lg.Debug(
607 "transmitted",
608 zap.String("data-size", humanize.Bytes(uint64(nr1))),
609 zap.String("from", s.From()),
610 zap.String("to", s.To()),
611 )
612 case proxyRx:
613 s.lg.Debug(
614 "received",
615 zap.String("data-size", humanize.Bytes(uint64(nr1))),
616 zap.String("from", s.To()),
617 zap.String("to", s.From()),
618 )
619 default:
620 panic("unknown proxy type")
621 }
622 }
623 }
624
625 func (s *server) Ready() <-chan struct{} { return s.readyc }
626 func (s *server) Done() <-chan struct{} { return s.donec }
627 func (s *server) Error() <-chan error { return s.errc }
628 func (s *server) Close() (err error) {
629 s.closeOnce.Do(func() {
630 close(s.donec)
631 s.listenerMu.Lock()
632 if s.listener != nil {
633 err = s.listener.Close()
634 s.lg.Info(
635 "closed proxy listener",
636 zap.String("from", s.From()),
637 zap.String("to", s.To()),
638 )
639 }
640 s.lg.Sync()
641 s.listenerMu.Unlock()
642 })
643 s.closeWg.Wait()
644 return err
645 }
646
647 func (s *server) PauseAccept() {
648 s.pauseAcceptMu.Lock()
649 s.pauseAcceptc = make(chan struct{})
650 s.pauseAcceptMu.Unlock()
651
652 s.lg.Info(
653 "paused accept",
654 zap.String("from", s.From()),
655 zap.String("to", s.To()),
656 )
657 }
658
659 func (s *server) UnpauseAccept() {
660 s.pauseAcceptMu.Lock()
661 select {
662 case <-s.pauseAcceptc:
663 case <-s.donec:
664 s.pauseAcceptMu.Unlock()
665 return
666 default:
667 close(s.pauseAcceptc)
668 }
669 s.pauseAcceptMu.Unlock()
670
671 s.lg.Info(
672 "unpaused accept",
673 zap.String("from", s.From()),
674 zap.String("to", s.To()),
675 )
676 }
677
678 func (s *server) DelayAccept(latency, rv time.Duration) {
679 if latency <= 0 {
680 return
681 }
682 d := computeLatency(latency, rv)
683 s.latencyAcceptMu.Lock()
684 s.latencyAccept = d
685 s.latencyAcceptMu.Unlock()
686
687 s.lg.Info(
688 "set accept latency",
689 zap.Duration("latency", d),
690 zap.Duration("given-latency", latency),
691 zap.Duration("given-latency-random-variable", rv),
692 zap.String("from", s.From()),
693 zap.String("to", s.To()),
694 )
695 }
696
697 func (s *server) UndelayAccept() {
698 s.latencyAcceptMu.Lock()
699 d := s.latencyAccept
700 s.latencyAccept = 0
701 s.latencyAcceptMu.Unlock()
702
703 s.lg.Info(
704 "removed accept latency",
705 zap.Duration("latency", d),
706 zap.String("from", s.From()),
707 zap.String("to", s.To()),
708 )
709 }
710
711 func (s *server) LatencyAccept() time.Duration {
712 s.latencyAcceptMu.RLock()
713 d := s.latencyAccept
714 s.latencyAcceptMu.RUnlock()
715 return d
716 }
717
718 func (s *server) DelayTx(latency, rv time.Duration) {
719 if latency <= 0 {
720 return
721 }
722 d := computeLatency(latency, rv)
723 s.latencyTxMu.Lock()
724 s.latencyTx = d
725 s.latencyTxMu.Unlock()
726
727 s.lg.Info(
728 "set transmit latency",
729 zap.Duration("latency", d),
730 zap.Duration("given-latency", latency),
731 zap.Duration("given-latency-random-variable", rv),
732 zap.String("from", s.From()),
733 zap.String("to", s.To()),
734 )
735 }
736
737 func (s *server) UndelayTx() {
738 s.latencyTxMu.Lock()
739 d := s.latencyTx
740 s.latencyTx = 0
741 s.latencyTxMu.Unlock()
742
743 s.lg.Info(
744 "removed transmit latency",
745 zap.Duration("latency", d),
746 zap.String("from", s.From()),
747 zap.String("to", s.To()),
748 )
749 }
750
751 func (s *server) LatencyTx() time.Duration {
752 s.latencyTxMu.RLock()
753 d := s.latencyTx
754 s.latencyTxMu.RUnlock()
755 return d
756 }
757
758 func (s *server) DelayRx(latency, rv time.Duration) {
759 if latency <= 0 {
760 return
761 }
762 d := computeLatency(latency, rv)
763 s.latencyRxMu.Lock()
764 s.latencyRx = d
765 s.latencyRxMu.Unlock()
766
767 s.lg.Info(
768 "set receive latency",
769 zap.Duration("latency", d),
770 zap.Duration("given-latency", latency),
771 zap.Duration("given-latency-random-variable", rv),
772 zap.String("from", s.To()),
773 zap.String("to", s.From()),
774 )
775 }
776
777 func (s *server) UndelayRx() {
778 s.latencyRxMu.Lock()
779 d := s.latencyRx
780 s.latencyRx = 0
781 s.latencyRxMu.Unlock()
782
783 s.lg.Info(
784 "removed receive latency",
785 zap.Duration("latency", d),
786 zap.String("from", s.To()),
787 zap.String("to", s.From()),
788 )
789 }
790
791 func (s *server) LatencyRx() time.Duration {
792 s.latencyRxMu.RLock()
793 d := s.latencyRx
794 s.latencyRxMu.RUnlock()
795 return d
796 }
797
// computeLatency returns lat shifted by a random offset in the open interval
// (-rv, rv). A negative rv is treated as its absolute value. An rv larger
// than lat is reduced to lat/10, which keeps the result positive for a
// positive lat. When the effective rv is not positive (rv == 0, or lat < 10ns
// making lat/10 round to 0), lat is returned unchanged.
//
// The offset and its sign come from math/rand's shared top-level functions,
// which are safe for concurrent use. This deliberately does not call
// rand.Seed: reseeding the global source would disturb every other user of
// math/rand in the process, and Seed is deprecated as of Go 1.20.
func computeLatency(lat, rv time.Duration) time.Duration {
	if rv < 0 {
		rv = -rv
	}
	if rv > lat {
		rv = lat / 10
	}
	// Int63n panics for n <= 0. rv can be 0 here (rv was 0, or lat/10 rounded
	// down to 0), or still negative if it was math.MinInt64, whose negation
	// overflows. In all of these cases there is nothing to randomize.
	if rv <= 0 {
		return lat
	}

	offset := time.Duration(mrand.Int63n(int64(rv)))
	if mrand.Intn(2) == 0 {
		offset = -offset
	}
	return lat + offset
}
816
817 func (s *server) ModifyTx(f func([]byte) []byte) {
818 s.modifyTxMu.Lock()
819 s.modifyTx = f
820 s.modifyTxMu.Unlock()
821
822 s.lg.Info(
823 "modifying tx",
824 zap.String("from", s.From()),
825 zap.String("to", s.To()),
826 )
827 }
828
829 func (s *server) UnmodifyTx() {
830 s.modifyTxMu.Lock()
831 s.modifyTx = nil
832 s.modifyTxMu.Unlock()
833
834 s.lg.Info(
835 "unmodifyed tx",
836 zap.String("from", s.From()),
837 zap.String("to", s.To()),
838 )
839 }
840
841 func (s *server) ModifyRx(f func([]byte) []byte) {
842 s.modifyRxMu.Lock()
843 s.modifyRx = f
844 s.modifyRxMu.Unlock()
845 s.lg.Info(
846 "modifying rx",
847 zap.String("from", s.To()),
848 zap.String("to", s.From()),
849 )
850 }
851
852 func (s *server) UnmodifyRx() {
853 s.modifyRxMu.Lock()
854 s.modifyRx = nil
855 s.modifyRxMu.Unlock()
856
857 s.lg.Info(
858 "unmodifyed rx",
859 zap.String("from", s.To()),
860 zap.String("to", s.From()),
861 )
862 }
863
864 func (s *server) BlackholeTx() {
865 s.ModifyTx(func([]byte) []byte { return nil })
866 s.lg.Info(
867 "blackholed tx",
868 zap.String("from", s.From()),
869 zap.String("to", s.To()),
870 )
871 }
872
873 func (s *server) UnblackholeTx() {
874 s.UnmodifyTx()
875 s.lg.Info(
876 "unblackholed tx",
877 zap.String("from", s.From()),
878 zap.String("to", s.To()),
879 )
880 }
881
882 func (s *server) BlackholeRx() {
883 s.ModifyRx(func([]byte) []byte { return nil })
884 s.lg.Info(
885 "blackholed rx",
886 zap.String("from", s.To()),
887 zap.String("to", s.From()),
888 )
889 }
890
891 func (s *server) UnblackholeRx() {
892 s.UnmodifyRx()
893 s.lg.Info(
894 "unblackholed rx",
895 zap.String("from", s.To()),
896 zap.String("to", s.From()),
897 )
898 }
899
900 func (s *server) PauseTx() {
901 s.pauseTxMu.Lock()
902 s.pauseTxc = make(chan struct{})
903 s.pauseTxMu.Unlock()
904
905 s.lg.Info(
906 "paused tx",
907 zap.String("from", s.From()),
908 zap.String("to", s.To()),
909 )
910 }
911
912 func (s *server) UnpauseTx() {
913 s.pauseTxMu.Lock()
914 select {
915 case <-s.pauseTxc:
916 case <-s.donec:
917 s.pauseTxMu.Unlock()
918 return
919 default:
920 close(s.pauseTxc)
921 }
922 s.pauseTxMu.Unlock()
923
924 s.lg.Info(
925 "unpaused tx",
926 zap.String("from", s.From()),
927 zap.String("to", s.To()),
928 )
929 }
930
931 func (s *server) PauseRx() {
932 s.pauseRxMu.Lock()
933 s.pauseRxc = make(chan struct{})
934 s.pauseRxMu.Unlock()
935
936 s.lg.Info(
937 "paused rx",
938 zap.String("from", s.To()),
939 zap.String("to", s.From()),
940 )
941 }
942
943 func (s *server) UnpauseRx() {
944 s.pauseRxMu.Lock()
945 select {
946 case <-s.pauseRxc:
947 case <-s.donec:
948 s.pauseRxMu.Unlock()
949 return
950 default:
951 close(s.pauseRxc)
952 }
953 s.pauseRxMu.Unlock()
954
955 s.lg.Info(
956 "unpaused rx",
957 zap.String("from", s.To()),
958 zap.String("to", s.From()),
959 )
960 }
961
962 func (s *server) ResetListener() error {
963 s.listenerMu.Lock()
964 defer s.listenerMu.Unlock()
965
966 if err := s.listener.Close(); err != nil {
967
968 if !strings.HasSuffix(err.Error(), "use of closed network connection") {
969 return err
970 }
971 }
972
973 var ln net.Listener
974 var err error
975 if !s.tlsInfo.Empty() {
976 ln, err = transport.NewListener(s.from.Host, s.from.Scheme, &s.tlsInfo)
977 } else {
978 ln, err = net.Listen(s.from.Scheme, s.from.Host)
979 }
980 if err != nil {
981 return err
982 }
983 s.listener = ln
984
985 s.lg.Info(
986 "reset listener on",
987 zap.String("from", s.From()),
988 )
989 return nil
990 }
991
View as plain text