1 package wsproxy
2
3 import (
4 "bufio"
5 "fmt"
6 "io"
7 "net/http"
8 "strings"
9 "time"
10
11 "github.com/gorilla/websocket"
12 "github.com/sirupsen/logrus"
13 "golang.org/x/net/context"
14 )
15
16
17
18
// Package-level defaults that WebsocketProxy copies into each new Proxy.
// Changing them affects only proxies constructed afterwards.
var (
	// MethodOverrideParam is the default name of the query-string parameter
	// whose non-empty value replaces the method of the request handed to the
	// wrapped handler (see WithMethodParamOverride).
	MethodOverrideParam = "method"

	// TokenCookieName is the default name of the cookie whose value is
	// forwarded to the wrapped handler as "Authorization: Bearer <value>"
	// (see WithTokenCookieName).
	TokenCookieName = "token"
)

// RequestMutatorFunc lets callers adjust the request passed to the wrapped
// handler. It receives the incoming websocket upgrade request and the
// prepared outgoing request, and returns the request that is actually used.
type RequestMutatorFunc func(incoming, outgoing *http.Request) *http.Request

// Proxy is an http.Handler that bridges websocket connections to a wrapped,
// streaming http.Handler. Requests that are not websocket upgrades are passed
// to the wrapped handler untouched. Construct it with WebsocketProxy.
type Proxy struct {
	// h is the wrapped handler serving both plain and proxied requests.
	h http.Handler
	// logger receives warnings and debug traces.
	logger Logger

	// maxRespBodyBufferBytes, when > 0, is passed to bufio.Scanner.Buffer as
	// the maximum size of one newline-delimited response message.
	maxRespBodyBufferBytes int
	// methodOverrideParam names the query parameter that overrides the
	// outgoing request method.
	methodOverrideParam string
	// tokenCookieName names the cookie forwarded as a bearer token.
	tokenCookieName string
	// requestMutator, if non-nil, has the last word on the outgoing request.
	requestMutator RequestMutatorFunc
	// headerForwarder decides which incoming header names are copied to the
	// outgoing request.
	headerForwarder func(header string) bool

	// Ping control; all three must be > 0 for pings to be sent.
	pingInterval time.Duration // time between pings
	pingWait     time.Duration // write deadline for each ping
	pongWait     time.Duration // read deadline, pushed forward on every pong
}

// Logger is the minimal logging interface Proxy needs; the logger created
// by logrus.New in WebsocketProxy satisfies it.
type Logger interface {
	Warnln(...interface{})
	Debugln(...interface{})
}
48
49 func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
50 if !websocket.IsWebSocketUpgrade(r) {
51 p.h.ServeHTTP(w, r)
52 return
53 }
54 p.proxy(w, r)
55 }
56
57
58 type Option func(*Proxy)
59
60
61
62
63 func WithMaxRespBodyBufferSize(nBytes int) Option {
64 return func(p *Proxy) {
65 p.maxRespBodyBufferBytes = nBytes
66 }
67 }
68
69
70 func WithMethodParamOverride(param string) Option {
71 return func(p *Proxy) {
72 p.methodOverrideParam = param
73 }
74 }
75
76
77 func WithTokenCookieName(param string) Option {
78 return func(p *Proxy) {
79 p.tokenCookieName = param
80 }
81 }
82
83
84 func WithRequestMutator(fn RequestMutatorFunc) Option {
85 return func(p *Proxy) {
86 p.requestMutator = fn
87 }
88 }
89
90
91 func WithForwardedHeaders(fn func(header string) bool) Option {
92 return func(p *Proxy) {
93 p.headerForwarder = fn
94 }
95 }
96
97
98 func WithLogger(logger Logger) Option {
99 return func(p *Proxy) {
100 p.logger = logger
101 }
102 }
103
104
105
106
107 func WithPingControl(interval time.Duration) Option {
108 return func(proxy *Proxy) {
109 proxy.pingInterval = interval
110 proxy.pongWait = (interval * 10) / 9
111 proxy.pingWait = proxy.pongWait / 6
112 }
113 }
114
// defaultHeadersToForward lists the incoming header names copied to the
// proxied request unless WithForwardedHeaders overrides the rule. Both the
// canonical and lowercase spellings are listed.
var defaultHeadersToForward = map[string]bool{
	"Origin":  true,
	"Referer": true,
	"origin":  true,
	"referer": true,
}

// defaultHeaderForwarder reports whether header appears in
// defaultHeadersToForward with a true value. Matching is exact (case-sensitive).
func defaultHeaderForwarder(header string) bool {
	forward, found := defaultHeadersToForward[header]
	return found && forward
}
125
126
127
128
129
130
131
132
133
134
135
136
137
138 func WebsocketProxy(h http.Handler, opts ...Option) http.Handler {
139 p := &Proxy{
140 h: h,
141 logger: logrus.New(),
142 methodOverrideParam: MethodOverrideParam,
143 tokenCookieName: TokenCookieName,
144 headerForwarder: defaultHeaderForwarder,
145 }
146 for _, o := range opts {
147 o(p)
148 }
149 return p
150 }
151
152
153 var upgrader = websocket.Upgrader{
154 ReadBufferSize: 1024,
155 WriteBufferSize: 1024,
156 CheckOrigin: func(r *http.Request) bool { return true },
157 }
158
159 func isClosedConnError(err error) bool {
160 str := err.Error()
161 if strings.Contains(str, "use of closed network connection") {
162 return true
163 }
164 return websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway)
165 }
166
167 func (p *Proxy) proxy(w http.ResponseWriter, r *http.Request) {
168 var responseHeader http.Header
169
170
171 if strings.HasPrefix(r.Header.Get("Sec-WebSocket-Protocol"), "Bearer") {
172 responseHeader = http.Header{
173 "Sec-WebSocket-Protocol": []string{"Bearer"},
174 }
175 }
176 conn, err := upgrader.Upgrade(w, r, responseHeader)
177 if err != nil {
178 p.logger.Warnln("error upgrading websocket:", err)
179 return
180 }
181 defer conn.Close()
182
183 ctx, cancelFn := context.WithCancel(context.Background())
184 defer cancelFn()
185
186 requestBodyR, requestBodyW := io.Pipe()
187 request, err := http.NewRequestWithContext(r.Context(), r.Method, r.URL.String(), requestBodyR)
188 if err != nil {
189 p.logger.Warnln("error preparing request:", err)
190 return
191 }
192 if swsp := r.Header.Get("Sec-WebSocket-Protocol"); swsp != "" {
193 request.Header.Set("Authorization", transformSubProtocolHeader(swsp))
194 }
195 for header := range r.Header {
196 if p.headerForwarder(header) {
197 request.Header.Set(header, r.Header.Get(header))
198 }
199 }
200
201 if cookie, err := r.Cookie(p.tokenCookieName); err == nil {
202 request.Header.Set("Authorization", "Bearer "+cookie.Value)
203 }
204 if m := r.URL.Query().Get(p.methodOverrideParam); m != "" {
205 request.Method = m
206 }
207
208 if p.requestMutator != nil {
209 request = p.requestMutator(r, request)
210 }
211
212 responseBodyR, responseBodyW := io.Pipe()
213 response := newInMemoryResponseWriter(responseBodyW)
214 go func() {
215 <-ctx.Done()
216 p.logger.Debugln("closing pipes")
217 requestBodyW.CloseWithError(io.EOF)
218 responseBodyW.CloseWithError(io.EOF)
219 response.closed <- true
220 }()
221
222 go func() {
223 defer cancelFn()
224 p.h.ServeHTTP(response, request)
225 }()
226
227
228 go func() {
229 if p.pingInterval > 0 && p.pingWait > 0 && p.pongWait > 0 {
230 conn.SetReadDeadline(time.Now().Add(p.pongWait))
231 conn.SetPongHandler(func(string) error { conn.SetReadDeadline(time.Now().Add(p.pongWait)); return nil })
232 }
233 defer func() {
234 cancelFn()
235 }()
236 for {
237 select {
238 case <-ctx.Done():
239 p.logger.Debugln("read loop done")
240 return
241 default:
242 }
243 p.logger.Debugln("[read] reading from socket.")
244 _, payload, err := conn.ReadMessage()
245 if err != nil {
246 if isClosedConnError(err) {
247 p.logger.Debugln("[read] websocket closed:", err)
248 return
249 }
250 p.logger.Warnln("error reading websocket message:", err)
251 return
252 }
253 p.logger.Debugln("[read] read payload:", string(payload))
254 p.logger.Debugln("[read] writing to requestBody:")
255 n, err := requestBodyW.Write(payload)
256 requestBodyW.Write([]byte("\n"))
257 p.logger.Debugln("[read] wrote to requestBody", n)
258 if err != nil {
259 p.logger.Warnln("[read] error writing message to upstream http server:", err)
260 return
261 }
262 }
263 }()
264
265 if p.pingInterval > 0 && p.pingWait > 0 && p.pongWait > 0 {
266 go func() {
267 ticker := time.NewTicker(p.pingInterval)
268 defer func() {
269 ticker.Stop()
270 conn.Close()
271 }()
272 for {
273 select {
274 case <-ctx.Done():
275 p.logger.Debugln("ping loop done")
276 return
277 case <-ticker.C:
278 conn.SetWriteDeadline(time.Now().Add(p.pingWait))
279 if err := conn.WriteMessage(websocket.PingMessage, nil); err != nil {
280 return
281 }
282 }
283 }
284 }()
285 }
286
287 scanner := bufio.NewScanner(responseBodyR)
288
289
290 var scannerBuf []byte
291 if p.maxRespBodyBufferBytes > 0 {
292 scannerBuf = make([]byte, 0, 64*1024)
293 scanner.Buffer(scannerBuf, p.maxRespBodyBufferBytes)
294 }
295
296 for scanner.Scan() {
297 if len(scanner.Bytes()) == 0 {
298 p.logger.Warnln("[write] empty scan", scanner.Err())
299 continue
300 }
301 p.logger.Debugln("[write] scanned", scanner.Text())
302 if err = conn.WriteMessage(websocket.TextMessage, scanner.Bytes()); err != nil {
303 p.logger.Warnln("[write] error writing websocket message:", err)
304 return
305 }
306 }
307 if err := scanner.Err(); err != nil {
308 p.logger.Warnln("scanner err:", err)
309 }
310 }
311
// inMemoryResponseWriter is the http.ResponseWriter handed to the wrapped
// handler for a proxied websocket. Body bytes go to the embedded Writer,
// while the header map and status code are only stored on the struct.
type inMemoryResponseWriter struct {
	// Writer receives the response body.
	io.Writer
	// header holds headers set by the handler.
	header http.Header
	// code is the last status passed to WriteHeader.
	code int
	// closed is the channel returned by CloseNotify.
	closed chan bool
}

// newInMemoryResponseWriter returns a response writer that streams its body to
// w. It starts with an empty header map and a one-slot closed channel, so a
// single close notification can be sent without a waiting receiver.
func newInMemoryResponseWriter(w io.Writer) *inMemoryResponseWriter {
	rw := &inMemoryResponseWriter{Writer: w}
	rw.header = make(http.Header)
	rw.closed = make(chan bool, 1)
	return rw
}
326
327
// transformSubProtocolHeader turns a Sec-WebSocket-Protocol value of the form
// "<...>Bearer,<token>" into an Authorization value "Bearer <token>". Spaces
// around the token are trimmed. If header contains no "Bearer," marker, the
// empty string is returned.
func transformSubProtocolHeader(header string) string {
	const marker = "Bearer,"
	at := strings.Index(header, marker)
	if at < 0 {
		return ""
	}
	token := strings.Trim(header[at+len(marker):], " ")
	return fmt.Sprintf("Bearer %v", token)
}
337
338 func (w *inMemoryResponseWriter) Write(b []byte) (int, error) {
339 return w.Writer.Write(b)
340 }
341 func (w *inMemoryResponseWriter) Header() http.Header {
342 return w.header
343 }
344 func (w *inMemoryResponseWriter) WriteHeader(code int) {
345 w.code = code
346 }
347 func (w *inMemoryResponseWriter) CloseNotify() <-chan bool {
348 return w.closed
349 }
350 func (w *inMemoryResponseWriter) Flush() {}
351
View as plain text