1
16
17 package proxy
18
19 import (
20 "bufio"
21 "bytes"
22 "fmt"
23 "io"
24 "log"
25 "net"
26 "net/http"
27 "net/http/httputil"
28 "net/url"
29 "os"
30 "strings"
31 "time"
32
33 "k8s.io/apimachinery/pkg/api/errors"
34 "k8s.io/apimachinery/pkg/util/httpstream"
35 utilnet "k8s.io/apimachinery/pkg/util/net"
36 utilruntime "k8s.io/apimachinery/pkg/util/runtime"
37
38 "github.com/mxk/go-flowrate/flowrate"
39
40 "k8s.io/klog/v2"
41 )
42
43
44
45
46
// UpgradeRequestRoundTripper is an http.RoundTripper that can also prepare the
// request half of a protocol upgrade. When UpgradeAwareHandler.UpgradeTransport
// is set, DialForUpgrade uses it instead of UpgradeAwareHandler.Transport.
type UpgradeRequestRoundTripper interface {
	http.RoundTripper
	// WrapRequest returns the request that should actually be written to the
	// backend for an upgrade. Presumably this is the input request decorated
	// with HTTP-level values such as authentication headers (confirm per
	// implementation). DialForUpgrade dials the backend with this round tripper
	// and writes the returned request onto that connection. The backend's
	// response is read by the caller and is never passed back to the
	// implementation.
	WrapRequest(*http.Request) (*http.Request, error)
}
55
56
// UpgradeAwareHandler is an http.Handler that proxies requests to a backend and
// transparently proxies protocol-upgrade requests over a hijacked connection.
type UpgradeAwareHandler struct {
	// UpgradeRequired, when true, makes ServeHTTP answer any request that was
	// not handled as an upgrade with a 400 BadRequest ("Upgrade request
	// required") instead of proxying it.
	UpgradeRequired bool
	// Location is the backend URL.
	//   - Plain requests: its scheme and host build the reverse-proxy target,
	//     and a copy of it (with the request's query) becomes the outgoing URL
	//     unless UseRequestLocation is set.
	//   - Upgrades: it is the dial target, unless UseRequestLocation is set.
	Location *url.URL
	// AppendLocationPath adds Location.Path in front of the proxied path.
	//   - Plain requests: it becomes the reverse-proxy target path.
	//   - Upgrades with UseRequestLocation: it is joined in front of the
	//     request path.
	AppendLocationPath bool
	// Transport is the round tripper for non-upgrade requests. It is also used
	// to dial upgrades when UpgradeTransport is nil. ServeHTTP wraps it with
	// defaultProxyTransport when it is nil or when WrapTransport is set.
	Transport http.RoundTripper
	// UpgradeTransport, when non-nil, replaces Transport for upgrade requests.
	// DialForUpgrade first passes the request through its WrapRequest and then
	// dials with it.
	UpgradeTransport UpgradeRequestRoundTripper
	// WrapTransport makes ServeHTTP wrap Transport with defaultProxyTransport
	// even when Transport is non-nil.
	WrapTransport bool
	// UseRequestLocation keeps the incoming request's URL instead of replacing
	// it with Location. For upgrades, Location's scheme and host are still
	// substituted into that URL.
	UseRequestLocation bool
	// UseLocationHost overrides the outgoing request's Host field with
	// Location.Host, for both plain and upgrade requests.
	UseLocationHost bool
	// FlushInterval is passed to httputil.ReverseProxy.FlushInterval for
	// non-upgrade requests.
	FlushInterval time.Duration
	// MaxBytesPerSec, when greater than zero, limits each direction of an
	// upgraded connection through flowrate readers and writers.
	MaxBytesPerSec int64
	// Responder reports errors. ServeHTTP installs it as the reverse proxy's
	// ErrorHandler when non-nil. The upgrade path and the UpgradeRequired check
	// call it without a nil check.
	Responder ErrorResponder
	// RejectForwardingRedirects replaces any backend 3xx response that carries
	// a Location header with a 502 Bad Gateway, instead of forwarding the
	// redirect to the client.
	RejectForwardingRedirects bool
}
89
// defaultFlushInterval (200ms) is the FlushInterval that NewUpgradeAwareHandler
// assigns. ServeHTTP passes it to httputil.ReverseProxy.FlushInterval.
const defaultFlushInterval time.Duration = time.Second / 5
91
92
93
// ErrorResponder reports a proxying error for a specific request/response pair,
// letting callers choose how the error is rendered.
type ErrorResponder interface {
	Error(w http.ResponseWriter, req *http.Request, err error)
}

// SimpleErrorResponder reports a proxying error without access to the
// response writer or the request.
type SimpleErrorResponder interface {
	Error(err error)
}

// NewErrorResponder adapts a SimpleErrorResponder into an ErrorResponder. The
// returned value ignores the writer and request and forwards only err to r.
func NewErrorResponder(r SimpleErrorResponder) ErrorResponder {
	adapter := simpleResponder{responder: r}
	return adapter
}

// simpleResponder is the ErrorResponder produced by NewErrorResponder.
type simpleResponder struct {
	responder SimpleErrorResponder
}

// Error implements ErrorResponder by delegating err to the wrapped
// SimpleErrorResponder; w and req are intentionally unused.
func (s simpleResponder) Error(_ http.ResponseWriter, _ *http.Request, err error) {
	s.responder.Error(err)
}
115
116
117 type upgradeRequestRoundTripper struct {
118 http.RoundTripper
119 upgrader http.RoundTripper
120 }
121
122 var (
123 _ UpgradeRequestRoundTripper = &upgradeRequestRoundTripper{}
124 _ utilnet.RoundTripperWrapper = &upgradeRequestRoundTripper{}
125 )
126
127
128 func (rt *upgradeRequestRoundTripper) WrappedRoundTripper() http.RoundTripper {
129 return rt.RoundTripper
130 }
131
132
133
134 func (rt *upgradeRequestRoundTripper) WrapRequest(req *http.Request) (*http.Request, error) {
135 resp, err := rt.upgrader.RoundTrip(req)
136 if err != nil {
137 return nil, err
138 }
139 return resp.Request, nil
140 }
141
142
143
// onewayRoundTripper never performs network I/O. It answers every request with
// an empty "200 OK" response whose Request field is the request it was given.
type onewayRoundTripper struct{}

// RoundTrip implements http.RoundTripper without sending req anywhere.
func (onewayRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Request:    req,
	}
	resp.Status = "200 OK"
	resp.Body = io.NopCloser(new(bytes.Buffer))
	return resp, nil
}

// MirrorRequest is a RoundTripper that hands back the request it receives,
// wrapped in a synthetic empty 200 response, without contacting any server.
// When it is the request round tripper given to NewUpgradeRequestRoundTripper,
// WrapRequest returns the request exactly as it reaches MirrorRequest.
var MirrorRequest http.RoundTripper = onewayRoundTripper{}
159
160
161
162
163 func NewUpgradeRequestRoundTripper(connection, request http.RoundTripper) UpgradeRequestRoundTripper {
164 return &upgradeRequestRoundTripper{
165 RoundTripper: connection,
166 upgrader: request,
167 }
168 }
169
170
// normalizeLocation returns a copy of location whose Scheme defaults to "http"
// when empty; the input is never modified.
//
// The copy is normally made by re-parsing location.String(). Some URLs do not
// survive that round trip; for example, a Host that String percent-encodes
// into an ASCII escape, which Parse rejects. For those, a plain struct copy is
// used instead of dereferencing a nil parse result. A nil location yields nil.
func normalizeLocation(location *url.URL) *url.URL {
	if location == nil {
		return nil
	}
	normalized, err := url.Parse(location.String())
	if err != nil {
		cp := *location
		normalized = &cp
	}
	if len(normalized.Scheme) == 0 {
		normalized.Scheme = "http"
	}
	return normalized
}
178
179
180
181 func NewUpgradeAwareHandler(location *url.URL, transport http.RoundTripper, wrapTransport, upgradeRequired bool, responder ErrorResponder) *UpgradeAwareHandler {
182 return &UpgradeAwareHandler{
183 Location: normalizeLocation(location),
184 Transport: transport,
185 WrapTransport: wrapTransport,
186 UpgradeRequired: upgradeRequired,
187 FlushInterval: defaultFlushInterval,
188 Responder: responder,
189 }
190 }
191
// proxyRedirectsforRootPath handles GET and HEAD requests whose proxied path is
// empty. It answers them with a 301 redirect to the request path plus a
// trailing "/", keeping any raw query, and reports whether it wrote that
// redirect. Callers must stop processing when it returns true.
func proxyRedirectsforRootPath(path string, w http.ResponseWriter, req *http.Request) bool {
	if path != "" {
		return false
	}
	switch req.Method {
	case http.MethodGet, http.MethodHead:
	default:
		return false
	}

	target := req.URL.Path + "/"
	if q := req.URL.RawQuery; q != "" {
		target += "?" + q
	}
	w.Header().Set("Location", target)
	w.WriteHeader(http.StatusMovedPermanently)
	return true
}
211
212
213 func (h *UpgradeAwareHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
214 if h.tryUpgrade(w, req) {
215 return
216 }
217 if h.UpgradeRequired {
218 h.Responder.Error(w, req, errors.NewBadRequest("Upgrade request required"))
219 return
220 }
221
222 loc := *h.Location
223 loc.RawQuery = req.URL.RawQuery
224
225
226
227 if !strings.HasSuffix(loc.Path, "/") && strings.HasSuffix(req.URL.Path, "/") {
228 loc.Path += "/"
229 }
230
231 proxyRedirect := proxyRedirectsforRootPath(loc.Path, w, req)
232 if proxyRedirect {
233 return
234 }
235
236 if h.Transport == nil || h.WrapTransport {
237 h.Transport = h.defaultProxyTransport(req.URL, h.Transport)
238 }
239
240
241 newReq := req.WithContext(req.Context())
242 newReq.Header = utilnet.CloneHeader(req.Header)
243 if !h.UseRequestLocation {
244 newReq.URL = &loc
245 }
246 if h.UseLocationHost {
247
248
249 newReq.Host = h.Location.Host
250 }
251
252
253 reverseProxyLocation := &url.URL{Scheme: h.Location.Scheme, Host: h.Location.Host}
254 if h.AppendLocationPath {
255 reverseProxyLocation.Path = h.Location.Path
256 }
257
258 proxy := httputil.NewSingleHostReverseProxy(reverseProxyLocation)
259 proxy.Transport = h.Transport
260 proxy.FlushInterval = h.FlushInterval
261 proxy.ErrorLog = log.New(noSuppressPanicError{}, "", log.LstdFlags)
262 if h.RejectForwardingRedirects {
263 oldModifyResponse := proxy.ModifyResponse
264 proxy.ModifyResponse = func(response *http.Response) error {
265 code := response.StatusCode
266 if code >= 300 && code <= 399 && len(response.Header.Get("Location")) > 0 {
267
268 response.Body.Close()
269 msg := "the backend attempted to redirect this request, which is not permitted"
270
271 *response = http.Response{
272 StatusCode: http.StatusBadGateway,
273 Status: fmt.Sprintf("%d %s", response.StatusCode, http.StatusText(response.StatusCode)),
274 Body: io.NopCloser(strings.NewReader(msg)),
275 ContentLength: int64(len(msg)),
276 }
277 } else {
278 if oldModifyResponse != nil {
279 if err := oldModifyResponse(response); err != nil {
280 return err
281 }
282 }
283 }
284 return nil
285 }
286 }
287 if h.Responder != nil {
288
289
290
291 proxy.ErrorHandler = h.Responder.Error
292 }
293 proxy.ServeHTTP(w, newReq)
294 }
295
// noSuppressPanicError is the io.Writer behind the reverse proxy's ErrorLog.
// Log lines mentioning "suppressing panic" are dropped (while still reporting
// a full write); everything else is copied to os.Stderr.
type noSuppressPanicError struct{}

func (noSuppressPanicError) Write(p []byte) (n int, err error) {
	if !strings.Contains(string(p), "suppressing panic") {
		return os.Stderr.Write(p)
	}
	// Swallow the message but claim success so the logger does not complain.
	return len(p), nil
}
307
308
// tryUpgrade returns true if it handled req. That covers two cases: the
// connection was upgraded and proxied, or an error or the backend's response
// was already delivered to the client. It returns false only when req is not
// an upgrade request, in which case the caller proxies it normally.
func (h *UpgradeAwareHandler) tryUpgrade(w http.ResponseWriter, req *http.Request) bool {
	if !httpstream.IsUpgradeRequest(req) {
		klog.V(6).Infof("Request was not an upgrade")
		return false
	}

	var (
		backendConn net.Conn
		// rawResponse collects the raw bytes read from the backend while parsing its
		// response. NOTE(review): it is still nil when getResponse is called below, so
		// bytes.NewReader(rawResponse) contributes no data there.
		rawResponse []byte
		err         error
	)

	// Choose the dial target. By default this is Location. With
	// UseRequestLocation it is the incoming URL, re-pointed at Location's
	// scheme and host and optionally prefixed with Location.Path.
	location := *h.Location
	if h.UseRequestLocation {
		location = *req.URL
		location.Scheme = h.Location.Scheme
		location.Host = h.Location.Host
		if h.AppendLocationPath {
			location.Path = singleJoiningSlash(h.Location.Path, location.Path)
		}
	}

	// Work on a copy so the incoming request is not modified.
	clone := utilnet.CloneRequest(req)
	// AppendForwardedForHeader presumably records the client address in
	// X-Forwarded-For (confirm in utilnet). The non-upgrade path relies on
	// httputil.ReverseProxy, which sets that header itself.
	utilnet.AppendForwardedForHeader(clone)
	// NOTE(review): at V(6) this logs every request header, which may include credentials.
	klog.V(6).Infof("Connecting to backend proxy (direct dial) %s\n Headers: %v", &location, clone.Header)
	if h.UseLocationHost {
		clone.Host = h.Location.Host
	}
	clone.URL = &location
	klog.V(6).Infof("UpgradeAwareProxy: dialing for SPDY upgrade with headers: %v", clone.Header)
	// DialForUpgrade opens the backend connection and writes the upgrade request to it.
	backendConn, err = h.DialForUpgrade(clone)
	if err != nil {
		klog.V(6).Infof("Proxy connection error: %v", err)
		h.Responder.Error(w, req, err)
		return true
	}
	defer backendConn.Close()

	// Parse the backend's response to the upgrade request. headerBytes holds every byte
	// consumed from backendConn while parsing. That can include data past the end of
	// the headers, because getResponse reads through a bufio.Reader.
	backendHTTPResponse, headerBytes, err := getResponse(io.MultiReader(bytes.NewReader(rawResponse), backendConn))
	if err != nil {
		klog.V(6).Infof("Proxy connection error: %v", err)
		h.Responder.Error(w, req, err)
		return true
	}
	if len(headerBytes) > len(rawResponse) {
		// More bytes were read from the backend than rawResponse holds; keep all of them so
		// they can be replayed verbatim to the client once the upgrade succeeds.
		rawResponse = headerBytes
	}

	// A backend status that is neither 101 nor an error (>= 400) is reported to the
	// client as a generic error. Error statuses fall through and are echoed to the
	// client after the connection is hijacked.
	if backendHTTPResponse.StatusCode != http.StatusSwitchingProtocols && backendHTTPResponse.StatusCode < 400 {
		err := fmt.Errorf("invalid upgrade response: status code %d", backendHTTPResponse.StatusCode)
		klog.Errorf("Proxy upgrade error: %v", err)
		h.Responder.Error(w, req, err)
		return true
	}

	// Take over the client's underlying connection. From here on, output is
	// written directly to requestHijackedConn rather than through w.
	requestHijacker, ok := w.(http.Hijacker)
	if !ok {
		klog.Errorf("Unable to hijack response writer: %T", w)
		h.Responder.Error(w, req, fmt.Errorf("request connection cannot be hijacked: %T", w))
		return true
	}
	requestHijackedConn, _, err := requestHijacker.Hijack()
	if err != nil {
		klog.Errorf("Unable to hijack response: %v", err)
		h.Responder.Error(w, req, fmt.Errorf("error hijacking connection: %v", err))
		return true
	}
	defer requestHijackedConn.Close()

	if backendHTTPResponse.StatusCode != http.StatusSwitchingProtocols {
		// The backend refused the upgrade with an error status. Echo its response to the
		// client, bounding both the backend read and the client write by a 10s deadline,
		// then return so the deferred Close calls tear down both connections.
		klog.V(6).Infof("Proxy upgrade error, status code %d", backendHTTPResponse.StatusCode)

		deadline := time.Now().Add(10 * time.Second)
		backendConn.SetReadDeadline(deadline)
		requestHijackedConn.SetWriteDeadline(deadline)

		err := backendHTTPResponse.Write(requestHijackedConn)
		if err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
			klog.Errorf("Error proxying data from backend to client: %v", err)
		}

		return true
	}

	// The upgrade succeeded. Forward the backend's raw 101 response, plus any
	// bytes read past it, to the client before streaming the rest. A write
	// failure is reported but does not abort the proxying below.
	if len(rawResponse) > 0 {
		klog.V(6).Infof("Writing %d bytes to hijacked connection", len(rawResponse))
		if _, err = requestHijackedConn.Write(rawResponse); err != nil {
			utilruntime.HandleError(fmt.Errorf("Error proxying response from backend to client: %v", err))
		}
	}

	// Proxy the connection in both directions, one goroutine per direction. Each
	// direction is rate-limited through flowrate when MaxBytesPerSec > 0. Each
	// goroutine closes its channel when its copy ends; errors caused by an
	// already-closed connection are not logged.
	writerComplete := make(chan struct{})
	readerComplete := make(chan struct{})

	go func() {
		var writer io.WriteCloser
		if h.MaxBytesPerSec > 0 {
			writer = flowrate.NewWriter(backendConn, h.MaxBytesPerSec)
		} else {
			writer = backendConn
		}
		_, err := io.Copy(writer, requestHijackedConn)
		if err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
			klog.Errorf("Error proxying data from client to backend: %v", err)
		}
		close(writerComplete)
	}()

	go func() {
		var reader io.ReadCloser
		if h.MaxBytesPerSec > 0 {
			reader = flowrate.NewReader(backendConn, h.MaxBytesPerSec)
		} else {
			reader = backendConn
		}
		_, err := io.Copy(requestHijackedConn, reader)
		if err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
			klog.Errorf("Error proxying data from backend to client: %v", err)
		}
		close(readerComplete)
	}()

	// Return as soon as either direction finishes. The deferred Close calls then
	// close both connections, which should unblock the remaining copy goroutine
	// (net.Conn.Close unblocks pending Read/Write calls).
	select {
	case <-writerComplete:
	case <-readerComplete:
	}
	klog.V(6).Infof("Disconnecting from backend proxy %s\n Headers: %v", &location, clone.Header)

	return true
}
456
457
458
// singleJoiningSlash concatenates two path segments so that exactly one "/"
// separates them. A duplicate slash at the seam is collapsed, and a missing
// one is inserted.
func singleJoiningSlash(a, b string) string {
	trailing := strings.HasSuffix(a, "/")
	leading := strings.HasPrefix(b, "/")
	if trailing && leading {
		return a + strings.TrimPrefix(b, "/")
	}
	if trailing || leading {
		return a + b
	}
	return a + "/" + b
}
470
471 func (h *UpgradeAwareHandler) DialForUpgrade(req *http.Request) (net.Conn, error) {
472 if h.UpgradeTransport == nil {
473 return dial(req, h.Transport)
474 }
475 updatedReq, err := h.UpgradeTransport.WrapRequest(req)
476 if err != nil {
477 return nil, err
478 }
479 return dial(updatedReq, h.UpgradeTransport)
480 }
481
482
483
// getResponse parses an HTTP response from r. It also returns every byte
// consumed from r so far, which the caller can replay verbatim. Because
// parsing goes through a bufio.Reader, those bytes may extend past the end of
// the header block. On a parse error both the response and the byte slice are
// nil.
func getResponse(r io.Reader) (*http.Response, []byte, error) {
	var consumed bytes.Buffer
	consumed.Grow(256)
	tee := io.TeeReader(r, &consumed)

	parsed, err := http.ReadResponse(bufio.NewReader(tee), nil)
	if err != nil {
		return nil, nil, err
	}
	return parsed, consumed.Bytes(), nil
}
494
495
496 func dial(req *http.Request, transport http.RoundTripper) (net.Conn, error) {
497 conn, err := DialURL(req.Context(), req.URL, transport)
498 if err != nil {
499 return nil, fmt.Errorf("error dialing backend: %v", err)
500 }
501
502 if err = req.Write(conn); err != nil {
503 conn.Close()
504 return nil, fmt.Errorf("error sending request: %v", err)
505 }
506
507 return conn, err
508 }
509
510 func (h *UpgradeAwareHandler) defaultProxyTransport(url *url.URL, internalTransport http.RoundTripper) http.RoundTripper {
511 scheme := url.Scheme
512 host := url.Host
513 suffix := h.Location.Path
514 if strings.HasSuffix(url.Path, "/") && !strings.HasSuffix(suffix, "/") {
515 suffix += "/"
516 }
517 pathPrepend := strings.TrimSuffix(url.Path, suffix)
518 rewritingTransport := &Transport{
519 Scheme: scheme,
520 Host: host,
521 PathPrepend: pathPrepend,
522 RoundTripper: internalTransport,
523 }
524 return &corsRemovingTransport{
525 RoundTripper: rewritingTransport,
526 }
527 }
528
529
530
531
532 type corsRemovingTransport struct {
533 http.RoundTripper
534 }
535
536 var _ = utilnet.RoundTripperWrapper(&corsRemovingTransport{})
537
538 func (rt *corsRemovingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
539 resp, err := rt.RoundTripper.RoundTrip(req)
540 if err != nil {
541 return nil, err
542 }
543 removeCORSHeaders(resp)
544 return resp, nil
545 }
546
547 func (rt *corsRemovingTransport) WrappedRoundTripper() http.RoundTripper {
548 return rt.RoundTripper
549 }
550
551
552
// removeCORSHeaders strips the Access-Control-Allow-Credentials, -Headers,
// -Methods and -Origin headers from resp. All other headers are left
// untouched.
func removeCORSHeaders(resp *http.Response) {
	for _, name := range [...]string{
		"Access-Control-Allow-Credentials",
		"Access-Control-Allow-Headers",
		"Access-Control-Allow-Methods",
		"Access-Control-Allow-Origin",
	} {
		resp.Header.Del(name)
	}
}
559
View as plain text