1
16
17 package roundtripper
18
import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"mime"
	"net"
	"net/http"
	"net/http/httputil"
	"net/url"
	"regexp"

	"golang.org/x/net/http2"

	"sigs.k8s.io/gateway-api/conformance/utils/config"
)
37
// H2CPriorKnowledgeProtocol is the Request.Protocol value that makes
// DefaultRoundTripper use cleartext HTTP/2 with prior knowledge (h2c) instead
// of the standard net/http transport.
const H2CPriorKnowledgeProtocol = "H2C_PRIOR_KNOWLEDGE"
41
42
43
44 type RoundTripper interface {
45 CaptureRoundTrip(Request) (*CapturedRequest, *CapturedResponse, error)
46 }
47
48
// Request describes a single HTTP request to be sent by a RoundTripper.
// The field notes below describe how DefaultRoundTripper interprets them.
type Request struct {
	// URL is the target of the request.
	URL url.URL
	// Host, when non-empty, overrides the Host sent with the request.
	Host string
	// Protocol selects the transport: H2CPriorKnowledgeProtocol uses cleartext
	// HTTP/2 with prior knowledge; any other value (including "") uses the
	// standard net/http transport.
	Protocol string
	// Method is the HTTP method; empty means GET.
	Method string
	// Headers are set on the outgoing request.
	Headers map[string][]string
	// UnfollowRedirect, when true, returns redirect responses as-is instead of
	// following them.
	UnfollowRedirect bool
	// CertPem and KeyPem are a PEM-encoded certificate and private key. When
	// both are set together with Server, the certificate is presented as the
	// client certificate and is also the only trusted root CA. Combining them
	// with H2CPriorKnowledgeProtocol is an error.
	CertPem []byte
	KeyPem  []byte
	// Server is the TLS server name expected when CertPem/KeyPem are used.
	Server string
}

// String formats the request for logging. The PEM material is never printed:
// CertPem and KeyPem always appear as <truncated>.
func (r Request) String() string {
	const format = "{URL: %+v, Host: %v, Protocol: %v, Method: %v, Headers: %v, " +
		"UnfollowRedirect: %v, Server: %v, CertPem: <truncated>, KeyPem: <truncated>}"
	fields := []any{r.URL, r.Host, r.Protocol, r.Method, r.Headers, r.UnfollowRedirect, r.Server}
	return fmt.Sprintf(format, fields...)
}
74
75
76
// CapturedRequest describes the request as reported back in the response.
//
// When the response is JSON (judged by its Content-Type header),
// DefaultRoundTripper decodes the body into this struct. The JSON is
// presumably an echo of the request produced by the conformance backend;
// confirm against the backend in use. Otherwise only Method is filled in, with
// the method that was sent.
type CapturedRequest struct {
	// Path is the request path as reported in the response body.
	Path string `json:"path"`
	// Host is the host as reported in the response body.
	Host string `json:"host"`
	// Method is the HTTP method, either reported by the backend or (for
	// non-JSON responses) the method that was sent.
	Method string `json:"method"`
	// Protocol is the protocol string reported under the "proto" key.
	Protocol string `json:"proto"`
	// Headers are the request headers as reported in the response body.
	Headers map[string][]string `json:"headers"`

	// Namespace and Pod presumably identify the backend that served the
	// request; confirm against the echo backend.
	Namespace string `json:"namespace"`
	Pod       string `json:"pod"`
}
87
88
89
// RedirectRequest holds the components of a redirect response's Location URL.
// Host is the host name without any port. Port is empty when the Location
// carries no explicit port.
type RedirectRequest struct {
	Scheme, Host, Port, Path string
}
96
97
// CapturedResponse records the parts of an HTTP response that the round
// tripper exposes to callers.
type CapturedResponse struct {
	// StatusCode is the response status code.
	StatusCode int
	// ContentLength is copied from the response (net/http uses -1 for unknown).
	ContentLength int64
	// Protocol is the response protocol string, e.g. "HTTP/1.1".
	Protocol string
	// Headers are the response headers.
	Headers map[string][]string
	// RedirectRequest is filled in from the Location header only when the
	// status code is one IsRedirect accepts; otherwise it is nil.
	RedirectRequest *RedirectRequest
}
105
106
107
// DefaultRoundTripper is the RoundTripper implementation built on net/http,
// plus golang.org/x/net/http2 for h2c prior knowledge.
type DefaultRoundTripper struct {
	// Debug, when true, prints dumps of each request sent and each response
	// received to standard output.
	Debug bool
	// TimeoutConfig supplies RequestTimeout, which bounds each round trip.
	TimeoutConfig config.TimeoutConfig
	// CustomDialContext, if non-nil, is installed as the DialContext of the
	// net/http transport; nil selects net/http's default dialing.
	CustomDialContext func(context.Context, string, string) (net.Conn, error)
}
113
114 func (d *DefaultRoundTripper) httpTransport(request Request) (http.RoundTripper, error) {
115 transport := &http.Transport{
116 DialContext: d.CustomDialContext,
117
118
119
120
121
122
123
124 DisableKeepAlives: true,
125 }
126 if request.Server != "" && len(request.CertPem) != 0 && len(request.KeyPem) != 0 {
127 tlsConfig, err := tlsClientConfig(request.Server, request.CertPem, request.KeyPem)
128 if err != nil {
129 return nil, err
130 }
131 transport.TLSClientConfig = tlsConfig
132 }
133
134 return transport, nil
135 }
136
137 func (d *DefaultRoundTripper) h2cPriorKnowledgeTransport(request Request) (http.RoundTripper, error) {
138 if request.Server != "" && len(request.CertPem) != 0 && len(request.KeyPem) != 0 {
139 return nil, errors.New("request has configured cert and key but h2 prior knowledge is not encrypted")
140 }
141
142 transport := &http2.Transport{
143 AllowHTTP: true,
144 DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
145 var d net.Dialer
146 return d.DialContext(ctx, network, addr)
147 },
148 }
149
150 return transport, nil
151 }
152
153
154
155
156
157 func (d *DefaultRoundTripper) CaptureRoundTrip(request Request) (*CapturedRequest, *CapturedResponse, error) {
158 var transport http.RoundTripper
159 var err error
160
161 switch request.Protocol {
162 case H2CPriorKnowledgeProtocol:
163 transport, err = d.h2cPriorKnowledgeTransport(request)
164 default:
165 transport, err = d.httpTransport(request)
166 }
167
168 if err != nil {
169 return nil, nil, err
170 }
171
172 return d.defaultRoundTrip(request, transport)
173 }
174
175 func (d *DefaultRoundTripper) defaultRoundTrip(request Request, transport http.RoundTripper) (*CapturedRequest, *CapturedResponse, error) {
176 client := &http.Client{}
177
178 if request.UnfollowRedirect {
179 client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
180 return http.ErrUseLastResponse
181 }
182 }
183
184 client.Transport = transport
185
186 method := "GET"
187 if request.Method != "" {
188 method = request.Method
189 }
190 ctx, cancel := context.WithTimeout(context.Background(), d.TimeoutConfig.RequestTimeout)
191 defer cancel()
192 req, err := http.NewRequestWithContext(ctx, method, request.URL.String(), nil)
193 if err != nil {
194 return nil, nil, err
195 }
196
197 if request.Host != "" {
198 req.Host = request.Host
199 }
200
201 if request.Headers != nil {
202 for name, value := range request.Headers {
203 req.Header.Set(name, value[0])
204 }
205 }
206
207 if d.Debug {
208 var dump []byte
209 dump, err = httputil.DumpRequestOut(req, true)
210 if err != nil {
211 return nil, nil, err
212 }
213
214 fmt.Printf("Sending Request:\n%s\n\n", formatDump(dump, "< "))
215 }
216
217 resp, err := client.Do(req)
218 if err != nil {
219 return nil, nil, err
220 }
221 defer resp.Body.Close()
222
223 if d.Debug {
224 var dump []byte
225 dump, err = httputil.DumpResponse(resp, true)
226 if err != nil {
227 return nil, nil, err
228 }
229
230 fmt.Printf("Received Response:\n%s\n\n", formatDump(dump, "< "))
231 }
232
233 cReq := &CapturedRequest{}
234
235 body, err := io.ReadAll(resp.Body)
236 if err != nil {
237 return nil, nil, err
238 }
239
240
241 if resp.Header.Get("Content-type") == "application/json" {
242 err = json.Unmarshal(body, cReq)
243 if err != nil {
244 return nil, nil, fmt.Errorf("unexpected error reading response: %w", err)
245 }
246 } else {
247 cReq.Method = method
248 }
249
250 cRes := &CapturedResponse{
251 StatusCode: resp.StatusCode,
252 ContentLength: resp.ContentLength,
253 Protocol: resp.Proto,
254 Headers: resp.Header,
255 }
256
257 if IsRedirect(resp.StatusCode) {
258 redirectURL, err := resp.Location()
259 if err != nil {
260 return nil, nil, err
261 }
262 cRes.RedirectRequest = &RedirectRequest{
263 Scheme: redirectURL.Scheme,
264 Host: redirectURL.Hostname(),
265 Port: redirectURL.Port(),
266 Path: redirectURL.Path,
267 }
268 }
269
270 return cReq, cRes, nil
271 }
272
// tlsClientConfig builds a TLS client configuration for talking to server.
//
// certPem and keyPem are a PEM-encoded certificate/key pair that is presented
// as the client certificate. The same certPem is also installed as the only
// trusted root, so the server must present a certificate that verifies against
// it. server is used as the expected TLS server name.
//
// It returns an error if server is empty, if the key pair cannot be loaded, or
// if no certificate in certPem can be added to the root pool.
func tlsClientConfig(server string, certPem []byte, keyPem []byte) (*tls.Config, error) {
	// Check the cheap precondition before doing any parsing.
	if server == "" {
		return nil, errors.New("unexpected error, server name required for TLS")
	}

	cert, err := tls.X509KeyPair(certPem, keyPem)
	if err != nil {
		return nil, fmt.Errorf("unexpected error creating cert: %w", err)
	}

	certPool := x509.NewCertPool()
	// AppendCertsFromPEM only reports success or failure. There is no
	// underlying error to wrap (err is nil here), so wrapping it would print
	// "%!w(<nil>)".
	if !certPool.AppendCertsFromPEM(certPem) {
		return nil, errors.New("unexpected error adding trusted CA")
	}

	return &tls.Config{
		Certificates: []tls.Certificate{cert},
		ServerName:   server,
		RootCAs:      certPool,
	}, nil
}
299
300
// IsRedirect reports whether statusCode is treated as a redirect: any code
// from 300 Multiple Choices through 308 Permanent Redirect, except the unused
// 306. Note that this includes 304 Not Modified and 305 Use Proxy.
func IsRedirect(statusCode int) bool {
	if statusCode < http.StatusMultipleChoices || statusCode > http.StatusPermanentRedirect {
		return false
	}
	// 306 is the only code in [300, 308] that net/http defines no constant for.
	return statusCode != 306
}
315
316
// IsTimeoutError reports whether statusCode signals a timeout: 408 Request
// Timeout or 504 Gateway Timeout.
func IsTimeoutError(statusCode int) bool {
	return statusCode == http.StatusRequestTimeout || statusCode == http.StatusGatewayTimeout
}
325
// startLineRegex matches the empty position at the start of every line: (?m)
// makes ^ match after each newline as well as at the start of the input.
var startLineRegex = regexp.MustCompile(`(?m)^`)

// formatDump returns data as a string with prefix inserted at the beginning of
// every line. The data slice itself is not modified.
func formatDump(data []byte, prefix string) string {
	return string(startLineRegex.ReplaceAllLiteral(data, []byte(prefix)))
}
332
View as plain text