1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package httpproxy
16
17 import (
18 "bytes"
19 "context"
20 "fmt"
21 "io"
22 "io/ioutil"
23 "net"
24 "net/http"
25 "net/url"
26 "strings"
27 "sync/atomic"
28 "time"
29
30 "go.etcd.io/etcd/server/v3/etcdserver/api/v2http/httptypes"
31
32 "go.uber.org/zap"
33 )
34
var (
	// singleHopHeaders lists hop-by-hop header fields: they describe a single
	// transport-level connection and must not be forwarded by a proxy.
	// removeSingleHopHeaders deletes each of them from a header set.
	//
	// Entries are in canonical MIME header form (as produced by
	// http.CanonicalHeaderKey).
	singleHopHeaders = []string{
		"Connection",
		"Proxy-Connection", // non-standard, but still sent by some clients
		"Keep-Alive",
		"Proxy-Authenticate",
		"Proxy-Authorization",
		"Te",       // canonical form of "TE"
		"Trailer",  // the field's actual name (RFC 7230 section 4.4)
		"Trailers", // retained from the earlier list; not a standard field name
		"Transfer-Encoding",
		"Upgrade",
	}
)
50
51 func removeSingleHopHeaders(hdrs *http.Header) {
52 for _, h := range singleHopHeaders {
53 hdrs.Del(h)
54 }
55 }
56
// reverseProxy forwards each incoming request (see its ServeHTTP method) to
// one of the endpoints supplied by director.
type reverseProxy struct {
	lg        *zap.Logger       // logger for the proxy's info/debug messages
	director  *director         // supplies the endpoints tried for each request
	transport http.RoundTripper // performs the outbound round trip to an endpoint
}
62
63 func (p *reverseProxy) ServeHTTP(rw http.ResponseWriter, clientreq *http.Request) {
64 reportIncomingRequest(clientreq)
65 proxyreq := new(http.Request)
66 *proxyreq = *clientreq
67 startTime := time.Now()
68
69 var (
70 proxybody []byte
71 err error
72 )
73
74 if clientreq.Body != nil {
75 proxybody, err = ioutil.ReadAll(clientreq.Body)
76 if err != nil {
77 msg := fmt.Sprintf("failed to read request body: %v", err)
78 p.lg.Info("failed to read request body", zap.Error(err))
79 e := httptypes.NewHTTPError(http.StatusInternalServerError, "httpproxy: "+msg)
80 if we := e.WriteTo(rw); we != nil {
81 p.lg.Debug(
82 "error writing HTTPError to remote addr",
83 zap.String("remote-addr", clientreq.RemoteAddr),
84 zap.Error(we),
85 )
86 }
87 return
88 }
89 }
90
91
92 proxyreq.Header = make(http.Header)
93 copyHeader(proxyreq.Header, clientreq.Header)
94
95 normalizeRequest(proxyreq)
96 removeSingleHopHeaders(&proxyreq.Header)
97 maybeSetForwardedFor(proxyreq)
98
99 endpoints := p.director.endpoints()
100 if len(endpoints) == 0 {
101 msg := "zero endpoints currently available"
102 reportRequestDropped(clientreq, zeroEndpoints)
103
104
105 p.lg.Info(msg)
106 e := httptypes.NewHTTPError(http.StatusServiceUnavailable, "httpproxy: "+msg)
107 if we := e.WriteTo(rw); we != nil {
108 p.lg.Debug(
109 "error writing HTTPError to remote addr",
110 zap.String("remote-addr", clientreq.RemoteAddr),
111 zap.Error(we),
112 )
113 }
114 return
115 }
116
117 var requestClosed int32
118 completeCh := make(chan bool, 1)
119 closeNotifier, ok := rw.(http.CloseNotifier)
120 ctx, cancel := context.WithCancel(context.Background())
121 proxyreq = proxyreq.WithContext(ctx)
122 defer cancel()
123 if ok {
124 closeCh := closeNotifier.CloseNotify()
125 go func() {
126 select {
127 case <-closeCh:
128 atomic.StoreInt32(&requestClosed, 1)
129 p.lg.Info(
130 "client closed request prematurely",
131 zap.String("remote-addr", clientreq.RemoteAddr),
132 )
133 cancel()
134 case <-completeCh:
135 }
136 }()
137
138 defer func() {
139 completeCh <- true
140 }()
141 }
142
143 var res *http.Response
144
145 for _, ep := range endpoints {
146 if proxybody != nil {
147 proxyreq.Body = ioutil.NopCloser(bytes.NewBuffer(proxybody))
148 }
149 redirectRequest(proxyreq, ep.URL)
150
151 res, err = p.transport.RoundTrip(proxyreq)
152 if atomic.LoadInt32(&requestClosed) == 1 {
153 return
154 }
155 if err != nil {
156 reportRequestDropped(clientreq, failedSendingRequest)
157 p.lg.Info(
158 "failed to direct request",
159 zap.String("url", ep.URL.String()),
160 zap.Error(err),
161 )
162 ep.Failed()
163 continue
164 }
165
166 break
167 }
168
169 if res == nil {
170
171 msg := fmt.Sprintf("unable to get response from %d endpoint(s)", len(endpoints))
172 reportRequestDropped(clientreq, failedGettingResponse)
173 p.lg.Info(msg)
174 e := httptypes.NewHTTPError(http.StatusBadGateway, "httpproxy: "+msg)
175 if we := e.WriteTo(rw); we != nil {
176 p.lg.Debug(
177 "error writing HTTPError to remote addr",
178 zap.String("remote-addr", clientreq.RemoteAddr),
179 zap.Error(we),
180 )
181 }
182 return
183 }
184
185 defer res.Body.Close()
186 reportRequestHandled(clientreq, res, startTime)
187 removeSingleHopHeaders(&res.Header)
188 copyHeader(rw.Header(), res.Header)
189
190 rw.WriteHeader(res.StatusCode)
191 io.Copy(rw, res.Body)
192 }
193
// copyHeader appends every value of every field in src to dst, keeping any
// values dst already holds. Keys pass through http.Header.Add, so they end up
// in canonical form in dst.
func copyHeader(dst, src http.Header) {
	for name, values := range src {
		for i := range values {
			dst.Add(name, values[i])
		}
	}
}
201
// redirectRequest retargets req at loc by overwriting the scheme and host of
// req.URL in place; the path, query and all other request fields are left
// untouched.
func redirectRequest(req *http.Request, loc url.URL) {
	target := req.URL
	target.Scheme, target.Host = loc.Scheme, loc.Host
}
206
// normalizeRequest rewrites req in place so it is labelled as an HTTP/1.1
// request, and clears its Close flag.
func normalizeRequest(req *http.Request) {
	req.Proto, req.ProtoMajor, req.ProtoMinor = "HTTP/1.1", 1, 1
	req.Close = false
}
213
// maybeSetForwardedFor appends the client IP (the host part of
// req.RemoteAddr) to the X-Forwarded-For header of req. Any values already
// present are joined with ", " and kept in front of the new entry. If
// RemoteAddr is not in host:port form, req is left unchanged.
func maybeSetForwardedFor(req *http.Request) {
	host, _, splitErr := net.SplitHostPort(req.RemoteAddr)
	if splitErr != nil {
		return
	}

	value := host
	// Index the map directly so every prior value is kept, not just the first.
	if chain, present := req.Header["X-Forwarded-For"]; present {
		value = strings.Join(chain, ", ") + ", " + host
	}
	req.Header.Set("X-Forwarded-For", value)
}
228
View as plain text