1
16
17 package proxy
18
19 import (
20 "bytes"
21 "compress/flate"
22 "compress/gzip"
23 "fmt"
24 "io"
25 "net/http"
26 "net/url"
27 "path"
28 "strings"
29
30 "golang.org/x/net/html"
31 "golang.org/x/net/html/atom"
32 "k8s.io/klog/v2"
33
34 "k8s.io/apimachinery/pkg/api/errors"
35 "k8s.io/apimachinery/pkg/util/net"
36 "k8s.io/apimachinery/pkg/util/sets"
37 )
38
39
40
41
42
43 var atomsToAttrs = map[atom.Atom]sets.String{
44 atom.A: sets.NewString("href"),
45 atom.Applet: sets.NewString("codebase"),
46 atom.Area: sets.NewString("href"),
47 atom.Audio: sets.NewString("src"),
48 atom.Base: sets.NewString("href"),
49 atom.Blockquote: sets.NewString("cite"),
50 atom.Body: sets.NewString("background"),
51 atom.Button: sets.NewString("formaction"),
52 atom.Command: sets.NewString("icon"),
53 atom.Del: sets.NewString("cite"),
54 atom.Embed: sets.NewString("src"),
55 atom.Form: sets.NewString("action"),
56 atom.Frame: sets.NewString("longdesc", "src"),
57 atom.Head: sets.NewString("profile"),
58 atom.Html: sets.NewString("manifest"),
59 atom.Iframe: sets.NewString("longdesc", "src"),
60 atom.Img: sets.NewString("longdesc", "src", "usemap"),
61 atom.Input: sets.NewString("src", "usemap", "formaction"),
62 atom.Ins: sets.NewString("cite"),
63 atom.Link: sets.NewString("href"),
64 atom.Object: sets.NewString("classid", "codebase", "data", "usemap"),
65 atom.Q: sets.NewString("cite"),
66 atom.Script: sets.NewString("src"),
67 atom.Source: sets.NewString("src"),
68 atom.Video: sets.NewString("poster", "src"),
69
70
71 }
72
73
74
// Transport is an http.RoundTripper used by the proxy. It forwards requests
// through the embedded RoundTripper (http.DefaultTransport when nil) and
// rewrites Location headers and HTML links in responses so that links back to
// the backend point at the proxy instead.
type Transport struct {
	// Scheme and Host, when non-empty, are sent upstream as
	// X-Forwarded-Proto / X-Forwarded-Host and substituted into rewritten
	// links. PathPrepend is the proxy path prefix that is joined onto
	// X-Forwarded-Uri and onto the paths of rewritten links.
	Scheme, Host, PathPrepend string

	http.RoundTripper
}
82
83
84 func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
85
86 forwardedURI := path.Join(t.PathPrepend, req.URL.EscapedPath())
87 if strings.HasSuffix(req.URL.Path, "/") {
88 forwardedURI = forwardedURI + "/"
89 }
90 req.Header.Set("X-Forwarded-Uri", forwardedURI)
91 if len(t.Host) > 0 {
92 req.Header.Set("X-Forwarded-Host", t.Host)
93 }
94 if len(t.Scheme) > 0 {
95 req.Header.Set("X-Forwarded-Proto", t.Scheme)
96 }
97
98 rt := t.RoundTripper
99 if rt == nil {
100 rt = http.DefaultTransport
101 }
102 resp, err := rt.RoundTrip(req)
103
104 if err != nil {
105 return nil, errors.NewServiceUnavailable(fmt.Sprintf("error trying to reach service: %v", err))
106 }
107
108 if redirect := resp.Header.Get("Location"); redirect != "" {
109 targetURL, err := url.Parse(redirect)
110 if err != nil {
111 return nil, errors.NewInternalError(fmt.Errorf("error trying to parse Location header: %v", err))
112 }
113 resp.Header.Set("Location", t.rewriteURL(targetURL, req.URL, req.Host))
114 return resp, nil
115 }
116
117 cType := resp.Header.Get("Content-Type")
118 cType = strings.TrimSpace(strings.SplitN(cType, ";", 2)[0])
119 if cType != "text/html" {
120
121 return resp, nil
122 }
123
124 return t.rewriteResponse(req, resp)
125 }
126
127 var _ = net.RoundTripperWrapper(&Transport{})
128
129 func (rt *Transport) WrappedRoundTripper() http.RoundTripper {
130 return rt.RoundTripper
131 }
132
133
134
135
136 func (t *Transport) rewriteURL(url *url.URL, sourceURL *url.URL, sourceRequestHost string) string {
137
138
139
140
141
142
143
144
145
146
147
148
149 isDifferentHost := url.Host != "" && url.Host != sourceURL.Host && url.Host != sourceRequestHost
150 isRelative := !strings.HasPrefix(url.Path, "/")
151 if isDifferentHost || isRelative {
152 return url.String()
153 }
154
155
156
157 if !(url.Host == sourceRequestHost && t.Scheme == "" && t.Host == "") {
158 url.Scheme = t.Scheme
159 url.Host = t.Host
160 }
161
162 origPath := url.Path
163
164 if strings.HasPrefix(url.Path, t.PathPrepend) {
165 return url.String()
166 }
167 url.Path = path.Join(t.PathPrepend, url.Path)
168 if strings.HasSuffix(origPath, "/") {
169
170 url.Path += "/"
171 }
172
173 return url.String()
174 }
175
176
177
178
179 func rewriteHTML(reader io.Reader, writer io.Writer, urlRewriter func(*url.URL) string) error {
180
181 tokenizer := html.NewTokenizer(reader)
182
183 var err error
184 for err == nil {
185 tokenType := tokenizer.Next()
186 switch tokenType {
187 case html.ErrorToken:
188 err = tokenizer.Err()
189 case html.StartTagToken, html.SelfClosingTagToken:
190 token := tokenizer.Token()
191 if urlAttrs, ok := atomsToAttrs[token.DataAtom]; ok {
192 for i, attr := range token.Attr {
193 if urlAttrs.Has(attr.Key) {
194 url, err := url.Parse(attr.Val)
195 if err != nil {
196
197
198
199 continue
200 }
201 token.Attr[i].Val = urlRewriter(url)
202 }
203 }
204 }
205 _, err = writer.Write([]byte(token.String()))
206 default:
207 _, err = writer.Write(tokenizer.Raw())
208 }
209 }
210 if err != io.EOF {
211 return err
212 }
213 return nil
214 }
215
216
217
218 func (t *Transport) rewriteResponse(req *http.Request, resp *http.Response) (*http.Response, error) {
219 origBody := resp.Body
220 defer origBody.Close()
221
222 newContent := &bytes.Buffer{}
223 var reader io.Reader = origBody
224 var writer io.Writer = newContent
225 encoding := resp.Header.Get("Content-Encoding")
226 switch encoding {
227 case "gzip":
228 var err error
229 reader, err = gzip.NewReader(reader)
230 if err != nil {
231 return nil, fmt.Errorf("errorf making gzip reader: %v", err)
232 }
233 gzw := gzip.NewWriter(writer)
234 defer gzw.Close()
235 writer = gzw
236 case "deflate":
237 var err error
238 reader = flate.NewReader(reader)
239 flw, err := flate.NewWriter(writer, flate.BestCompression)
240 if err != nil {
241 return nil, fmt.Errorf("errorf making flate writer: %v", err)
242 }
243 defer func() {
244 flw.Close()
245 flw.Flush()
246 }()
247 writer = flw
248 case "":
249
250 default:
251
252 klog.Errorf("Proxy encountered encoding %v for text/html; can't understand this so not fixing links.", encoding)
253 return resp, nil
254 }
255
256 urlRewriter := func(targetUrl *url.URL) string {
257 return t.rewriteURL(targetUrl, req.URL, req.Host)
258 }
259 err := rewriteHTML(reader, writer, urlRewriter)
260 if err != nil {
261 klog.Errorf("Failed to rewrite URLs: %v", err)
262 return resp, err
263 }
264
265 resp.Body = io.NopCloser(newContent)
266
267
268 resp.Header.Del("Content-Length")
269 resp.ContentLength = int64(newContent.Len())
270
271 return resp, err
272 }
273
View as plain text