1
2
3
4
21
22 package transport
23
24 import (
25 "bufio"
26 "context"
27 "encoding/base64"
28 "fmt"
29 "io"
30 "net"
31 "net/http"
32 "net/url"
33 "testing"
34 "time"
35 )
36
// Fixed addresses used by TestMapAddressEnv.
const (
	// envTestAddr is the target address for which the fake proxy hook
	// reports a proxy.
	envTestAddr = "1.2.3.4:8080"
	// envProxyAddr is the proxy host the fake hook returns for envTestAddr.
	envProxyAddr = "2.3.4.5:7687"
)
41
42
43
44 func overwrite(hpfe func(req *http.Request) (*url.URL, error)) func() {
45 backHPFE := httpProxyFromEnvironment
46 httpProxyFromEnvironment = hpfe
47 return func() {
48 httpProxyFromEnvironment = backHPFE
49 }
50 }
51
// proxyServer is a single-connection HTTP CONNECT proxy used by the tests.
type proxyServer struct {
	t   *testing.T   // receives errors detected while proxying
	lis net.Listener // where the proxy accepts its client

	// in is the accepted client connection; out is the connection the proxy
	// dialed to the requested backend. Both stay nil until established.
	in, out net.Conn

	// requestCheck validates the CONNECT request read from the client; a
	// non-nil error makes the proxy reject the request.
	requestCheck func(*http.Request) error
}
60
61 func (p *proxyServer) run() {
62 in, err := p.lis.Accept()
63 if err != nil {
64 return
65 }
66 p.in = in
67
68 req, err := http.ReadRequest(bufio.NewReader(in))
69 if err != nil {
70 p.t.Errorf("failed to read CONNECT req: %v", err)
71 return
72 }
73 if err := p.requestCheck(req); err != nil {
74 resp := http.Response{StatusCode: http.StatusMethodNotAllowed}
75 resp.Write(p.in)
76 p.in.Close()
77 p.t.Errorf("get wrong CONNECT req: %+v, error: %v", req, err)
78 return
79 }
80
81 out, err := net.Dial("tcp", req.URL.Host)
82 if err != nil {
83 p.t.Errorf("failed to dial to server: %v", err)
84 return
85 }
86 resp := http.Response{StatusCode: http.StatusOK, Proto: "HTTP/1.0"}
87 resp.Write(p.in)
88 p.out = out
89 go io.Copy(p.in, p.out)
90 go io.Copy(p.out, p.in)
91 }
92
93 func (p *proxyServer) stop() {
94 p.lis.Close()
95 if p.in != nil {
96 p.in.Close()
97 }
98 if p.out != nil {
99 p.out.Close()
100 }
101 }
102
103 func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxyReqCheck func(*http.Request) error) {
104 plis, err := net.Listen("tcp", "localhost:0")
105 if err != nil {
106 t.Fatalf("failed to listen: %v", err)
107 }
108 p := &proxyServer{
109 t: t,
110 lis: plis,
111 requestCheck: proxyReqCheck,
112 }
113 go p.run()
114 defer p.stop()
115
116 blis, err := net.Listen("tcp", "localhost:0")
117 if err != nil {
118 t.Fatalf("failed to listen: %v", err)
119 }
120
121 msg := []byte{4, 3, 5, 2}
122 recvBuf := make([]byte, len(msg))
123 done := make(chan error, 1)
124 go func() {
125 in, err := blis.Accept()
126 if err != nil {
127 done <- err
128 return
129 }
130 defer in.Close()
131 in.Read(recvBuf)
132 done <- nil
133 }()
134
135
136 hpfe := func(req *http.Request) (*url.URL, error) {
137 return proxyURLModify(&url.URL{Host: plis.Addr().String()}), nil
138 }
139 defer overwrite(hpfe)()
140
141
142 ctx, cancel := context.WithTimeout(context.Background(), time.Second)
143 defer cancel()
144 c, err := proxyDial(ctx, blis.Addr().String(), "test")
145 if err != nil {
146 t.Fatalf("http connect Dial failed: %v", err)
147 }
148 defer c.Close()
149
150
151 c.Write(msg)
152 if err := <-done; err != nil {
153 t.Fatalf("failed to accept: %v", err)
154 }
155
156
157 if string(recvBuf) != string(msg) {
158 t.Fatalf("received msg: %v, want %v", recvBuf, msg)
159 }
160 }
161
162 func (s) TestHTTPConnect(t *testing.T) {
163 testHTTPConnect(t,
164 func(in *url.URL) *url.URL {
165 return in
166 },
167 func(req *http.Request) error {
168 if req.Method != http.MethodConnect {
169 return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
170 }
171 return nil
172 },
173 )
174 }
175
176 func (s) TestHTTPConnectBasicAuth(t *testing.T) {
177 const (
178 user = "notAUser"
179 password = "notAPassword"
180 )
181 testHTTPConnect(t,
182 func(in *url.URL) *url.URL {
183 in.User = url.UserPassword(user, password)
184 return in
185 },
186 func(req *http.Request) error {
187 if req.Method != http.MethodConnect {
188 return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
189 }
190 wantProxyAuthStr := "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+password))
191 if got := req.Header.Get(proxyAuthHeaderKey); got != wantProxyAuthStr {
192 gotDecoded, _ := base64.StdEncoding.DecodeString(got)
193 wantDecoded, _ := base64.StdEncoding.DecodeString(wantProxyAuthStr)
194 return fmt.Errorf("unexpected auth %q (%q), want %q (%q)", got, gotDecoded, wantProxyAuthStr, wantDecoded)
195 }
196 return nil
197 },
198 )
199 }
200
201 func (s) TestMapAddressEnv(t *testing.T) {
202
203 hpfe := func(req *http.Request) (*url.URL, error) {
204 if req.URL.Host == envTestAddr {
205 return &url.URL{
206 Scheme: "https",
207 Host: envProxyAddr,
208 }, nil
209 }
210 return nil, nil
211 }
212 defer overwrite(hpfe)()
213
214
215 got, err := mapAddress(envTestAddr)
216 if err != nil {
217 t.Error(err)
218 }
219 if got.Host != envProxyAddr {
220 t.Errorf("want %v, got %v", envProxyAddr, got)
221 }
222 }
223
View as plain text