1
16
17
18
19
20 package nettesting
21
22 import (
23 "io"
24 "net"
25 "net/http"
26 "net/http/httputil"
27 "sync"
28 "testing"
29
30 "github.com/onsi/ginkgo/v2"
31 )
32
// TB is the minimal logging surface HTTPProxyHandler needs from a test
// handle; *testing.T satisfies it.
//
// NewHTTPProxyHandler is stricter than this interface: it panics unless the
// value is also a testing.TB or a ginkgo.GinkgoTInterface.
type TB interface {
	// Logf records one log line built from a Printf-style format and args.
	Logf(format string, a ...any)
}
36
37
38
39
40 func NewHTTPProxyHandler(t TB, hook func(req *http.Request) bool) *HTTPProxyHandler {
41
42
43 switch t.(type) {
44 case testing.TB, ginkgo.GinkgoTInterface:
45 default:
46 panic("t is not a known test interface")
47 }
48 h := &HTTPProxyHandler{
49 hook: hook,
50 httpProxy: httputil.ReverseProxy{
51 Director: func(req *http.Request) {
52 req.URL.Scheme = "http"
53 req.URL.Host = req.Host
54 },
55 },
56 t: t,
57 }
58 return h
59 }
60
61
62
63 type HTTPProxyHandler struct {
64 handlerDone sync.WaitGroup
65 hook func(r *http.Request) bool
66
67 httpProxy httputil.ReverseProxy
68 t TB
69 }
70
// ServeHTTP implements http.Handler for the proxy.
//
// Every call is counted in handlerDone so that Wait can block until all
// in-flight requests have returned. When a hook is configured and returns
// false, the request is answered with 500 and not proxied. Otherwise the
// request is logged; non-CONNECT requests go through the reverse proxy, and
// CONNECT requests are tunnelled: the handler dials req.Host over TCP, replies
// 200, hijacks the client connection and relays bytes in both directions
// until each side has finished sending.
func (h *HTTPProxyHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	h.handlerDone.Add(1)
	defer h.handlerDone.Done()

	if h.hook != nil && !h.hook(req) {
		rw.WriteHeader(http.StatusInternalServerError)
		return
	}

	if b, err := httputil.DumpRequest(req, false); err != nil {
		h.t.Logf("Failed to dump request, host=%s: %v", req.Host, err)
	} else {
		h.t.Logf("Proxy Request: %s", string(b))
	}

	if req.Method != http.MethodConnect {
		h.httpProxy.ServeHTTP(rw, req)
		return
	}

	// Check for hijack support before dialing, so a ResponseWriter that cannot
	// be taken over does not cost a backend connection that is immediately
	// thrown away.
	hj, ok := rw.(http.Hijacker)
	if !ok {
		h.t.Logf("Can't switch protocols using non-Hijacker ResponseWriter: type=%T, host=%s", rw, req.Host)
		rw.WriteHeader(http.StatusInternalServerError)
		return
	}

	// Bind the dial to the request context so it is abandoned if the request
	// is cancelled while the backend is still being dialed.
	var dialer net.Dialer
	sconn, err := dialer.DialContext(req.Context(), "tcp", req.Host)
	if err != nil {
		h.t.Logf("Failed to dial proxy backend, host=%s: %v", req.Host, err)
		rw.WriteHeader(http.StatusInternalServerError)
		return
	}
	defer sconn.Close()

	rw.WriteHeader(http.StatusOK)

	conn, brw, err := hj.Hijack()
	if err != nil {
		h.t.Logf("Failed to hijack client connection, host=%s: %v", req.Host, err)
		return
	}
	defer conn.Close()

	// Push out anything still buffered for the client (presumably including
	// the 200 written above) before raw tunnel bytes start flowing.
	if err := brw.Flush(); err != nil {
		h.t.Logf("Failed to flush pending writes to client, host=%s: %v", req.Host, err)
		return
	}
	// Bytes the client sent right after the CONNECT request may already sit in
	// brw's read buffer; forward exactly those before reading conn directly.
	if _, err := io.Copy(sconn, io.LimitReader(brw, int64(brw.Reader.Buffered()))); err != nil {
		h.t.Logf("Failed to flush buffered reads to server, host=%s: %v", req.Host, err)
		return
	}

	var wg sync.WaitGroup
	// pipe copies src into dst until src reports EOF or an error, then
	// half-closes dst so the peer behind it observes EOF as well. Without that
	// propagation, once one side stopped sending the other peer would never
	// learn about it, and the opposite copy (and so this handler and Wait)
	// could block indefinitely.
	pipe := func(dst, src net.Conn, direction string) {
		defer wg.Done()
		defer h.t.Logf("Server %s close, host=%s", direction, req.Host)
		if _, err := io.Copy(dst, src); err != nil {
			h.t.Logf("Error copying tunnel data (%s), host=%s: %v", direction, req.Host, err)
		}
		if cw, ok := dst.(interface{ CloseWrite() error }); ok {
			_ = cw.CloseWrite() // best effort: the peer may already be gone
		} else {
			// No half-close available: a full close is the only way to signal
			// EOF. The deferred Close above then becomes a harmless second close.
			_ = dst.Close()
		}
	}

	wg.Add(2)
	go pipe(conn, sconn, "read")
	go pipe(sconn, conn, "write")
	wg.Wait()

	h.t.Logf("Done handling CONNECT request, host=%s", req.Host)
}
147
148 func (h *HTTPProxyHandler) Wait() {
149 h.handlerDone.Wait()
150 }
151
View as plain text