...

Source file src/sigs.k8s.io/gateway-api/conformance/utils/roundtripper/roundtripper.go

Documentation: sigs.k8s.io/gateway-api/conformance/utils/roundtripper

     1  /*
     2  Copyright 2022 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package roundtripper
    18  
import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"mime"
	"net"
	"net/http"
	"net/http/httputil"
	"net/url"
	"regexp"

	"golang.org/x/net/http2"

	"sigs.k8s.io/gateway-api/conformance/utils/config"
)
    37  
const (
	// H2CPriorKnowledgeProtocol is the Request.Protocol value that makes
	// DefaultRoundTripper speak cleartext HTTP/2 ("h2c") with prior knowledge:
	// HTTP/2 is spoken directly over a plain TCP connection, without TLS and
	// without an HTTP/1.1 Upgrade exchange. Any other Protocol value selects
	// the standard net/http transport.
	H2CPriorKnowledgeProtocol = "H2C_PRIOR_KNOWLEDGE"
)
    41  
// RoundTripper is an interface used to make requests within conformance tests.
// This can be overridden with custom implementations whenever necessary.
//
// CaptureRoundTrip sends the request described by its argument and returns the
// request metadata captured from the backend's response (CapturedRequest)
// together with metadata about the response itself (CapturedResponse).
type RoundTripper interface {
	CaptureRoundTrip(Request) (*CapturedRequest, *CapturedResponse, error)
}
    47  
// Request is the primary input for making a request.
type Request struct {
	// URL is the target of the request.
	URL url.URL
	// Host, when non-empty, overrides the Host sent with the request;
	// otherwise net/http derives it from URL.
	Host string
	// Protocol selects the transport used by DefaultRoundTripper:
	// H2CPriorKnowledgeProtocol selects cleartext HTTP/2, any other value the
	// standard net/http transport.
	Protocol string
	// Method is the HTTP method; DefaultRoundTripper uses GET when it is empty.
	Method string
	// Headers are the request headers to set on the outgoing request.
	Headers map[string][]string
	// UnfollowRedirect, when true, makes DefaultRoundTripper return a redirect
	// response as-is instead of following it.
	UnfollowRedirect bool
	// CertPem and KeyPem are a PEM-encoded certificate and private key. When
	// both are set together with Server, DefaultRoundTripper presents them as
	// a client certificate and trusts the certificate(s) in CertPem as root
	// CAs. DefaultRoundTripper rejects them for H2CPriorKnowledgeProtocol,
	// which is unencrypted.
	CertPem []byte
	KeyPem  []byte
	// Server is the TLS server name used together with CertPem and KeyPem.
	Server string
}
    60  
    61  // String returns a printable version of Request for logging. Note that the
    62  // CertPem and KeyPem are truncated.
    63  func (r Request) String() string {
    64  	return fmt.Sprintf("{URL: %+v, Host: %v, Protocol: %v, Method: %v, Headers: %v, UnfollowRedirect: %v, Server: %v, CertPem: <truncated>, KeyPem: <truncated>}",
    65  		r.URL,
    66  		r.Host,
    67  		r.Protocol,
    68  		r.Method,
    69  		r.Headers,
    70  		r.UnfollowRedirect,
    71  		r.Server,
    72  	)
    73  }
    74  
// CapturedRequest contains request metadata captured from an echoserver
// response. DefaultRoundTripper decodes it from the response body when the
// response is JSON; otherwise it only sets Method, to the method it sent.
type CapturedRequest struct {
	Path     string              `json:"path"`
	Host     string              `json:"host"`
	Method   string              `json:"method"`
	Protocol string              `json:"proto"`
	Headers  map[string][]string `json:"headers"`

	// NOTE(review): Namespace and Pod presumably identify the Kubernetes
	// backend that served the request, as reported by the echoserver —
	// confirm against the echoserver implementation.
	Namespace string `json:"namespace"`
	Pod       string `json:"pod"`
}
    87  
// RedirectRequest contains metadata about the follow-up request that a
// redirect response points to. DefaultRoundTripper fills it in from the
// response's Location header.
type RedirectRequest struct {
	Scheme string
	// Host is the host name only; any port is reported separately in Port.
	Host string
	// Port is empty when the Location URL carries no explicit port.
	Port string
	Path string
}
    96  
// CapturedResponse contains response metadata.
type CapturedResponse struct {
	StatusCode int
	// ContentLength is net/http's Response.ContentLength (-1 when unknown).
	ContentLength int64
	// Protocol is the response protocol version string, e.g. "HTTP/1.1".
	Protocol string
	Headers  map[string][]string
	// RedirectRequest is set by DefaultRoundTripper only when IsRedirect
	// reports true for StatusCode; it is nil otherwise.
	RedirectRequest *RedirectRequest
}
   105  
// DefaultRoundTripper is the default implementation of a RoundTripper. It will
// be used if a custom implementation is not specified.
type DefaultRoundTripper struct {
	// Debug, when true, prints dumps of requests and responses to standard
	// output.
	Debug bool
	// TimeoutConfig supplies RequestTimeout, the deadline applied to each
	// request.
	TimeoutConfig config.TimeoutConfig
	// CustomDialContext is installed as the DialContext of the net/http
	// transport; when nil, that transport falls back to its default dialing.
	CustomDialContext func(context.Context, string, string) (net.Conn, error)
}
   113  
   114  func (d *DefaultRoundTripper) httpTransport(request Request) (http.RoundTripper, error) {
   115  	transport := &http.Transport{
   116  		DialContext: d.CustomDialContext,
   117  		// We disable keep-alives so that we don't leak established TCP connections.
   118  		// Leaking TCP connections is bad because we could eventually hit the
   119  		// threshold of maximum number of open TCP connections to a specific
   120  		// destination. Keep-alives are not presently utilized so disabling this has
   121  		// no adverse affect.
   122  		//
   123  		// Ref. https://github.com/kubernetes-sigs/gateway-api/issues/2357
   124  		DisableKeepAlives: true,
   125  	}
   126  	if request.Server != "" && len(request.CertPem) != 0 && len(request.KeyPem) != 0 {
   127  		tlsConfig, err := tlsClientConfig(request.Server, request.CertPem, request.KeyPem)
   128  		if err != nil {
   129  			return nil, err
   130  		}
   131  		transport.TLSClientConfig = tlsConfig
   132  	}
   133  
   134  	return transport, nil
   135  }
   136  
   137  func (d *DefaultRoundTripper) h2cPriorKnowledgeTransport(request Request) (http.RoundTripper, error) {
   138  	if request.Server != "" && len(request.CertPem) != 0 && len(request.KeyPem) != 0 {
   139  		return nil, errors.New("request has configured cert and key but h2 prior knowledge is not encrypted")
   140  	}
   141  
   142  	transport := &http2.Transport{
   143  		AllowHTTP: true,
   144  		DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
   145  			var d net.Dialer
   146  			return d.DialContext(ctx, network, addr)
   147  		},
   148  	}
   149  
   150  	return transport, nil
   151  }
   152  
   153  // CaptureRoundTrip makes a request with the provided parameters and returns the
   154  // captured request and response from echoserver. An error will be returned if
   155  // there is an error running the function but not if an HTTP error status code
   156  // is received.
   157  func (d *DefaultRoundTripper) CaptureRoundTrip(request Request) (*CapturedRequest, *CapturedResponse, error) {
   158  	var transport http.RoundTripper
   159  	var err error
   160  
   161  	switch request.Protocol {
   162  	case H2CPriorKnowledgeProtocol:
   163  		transport, err = d.h2cPriorKnowledgeTransport(request)
   164  	default:
   165  		transport, err = d.httpTransport(request)
   166  	}
   167  
   168  	if err != nil {
   169  		return nil, nil, err
   170  	}
   171  
   172  	return d.defaultRoundTrip(request, transport)
   173  }
   174  
   175  func (d *DefaultRoundTripper) defaultRoundTrip(request Request, transport http.RoundTripper) (*CapturedRequest, *CapturedResponse, error) {
   176  	client := &http.Client{}
   177  
   178  	if request.UnfollowRedirect {
   179  		client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
   180  			return http.ErrUseLastResponse
   181  		}
   182  	}
   183  
   184  	client.Transport = transport
   185  
   186  	method := "GET"
   187  	if request.Method != "" {
   188  		method = request.Method
   189  	}
   190  	ctx, cancel := context.WithTimeout(context.Background(), d.TimeoutConfig.RequestTimeout)
   191  	defer cancel()
   192  	req, err := http.NewRequestWithContext(ctx, method, request.URL.String(), nil)
   193  	if err != nil {
   194  		return nil, nil, err
   195  	}
   196  
   197  	if request.Host != "" {
   198  		req.Host = request.Host
   199  	}
   200  
   201  	if request.Headers != nil {
   202  		for name, value := range request.Headers {
   203  			req.Header.Set(name, value[0])
   204  		}
   205  	}
   206  
   207  	if d.Debug {
   208  		var dump []byte
   209  		dump, err = httputil.DumpRequestOut(req, true)
   210  		if err != nil {
   211  			return nil, nil, err
   212  		}
   213  
   214  		fmt.Printf("Sending Request:\n%s\n\n", formatDump(dump, "< "))
   215  	}
   216  
   217  	resp, err := client.Do(req)
   218  	if err != nil {
   219  		return nil, nil, err
   220  	}
   221  	defer resp.Body.Close()
   222  
   223  	if d.Debug {
   224  		var dump []byte
   225  		dump, err = httputil.DumpResponse(resp, true)
   226  		if err != nil {
   227  			return nil, nil, err
   228  		}
   229  
   230  		fmt.Printf("Received Response:\n%s\n\n", formatDump(dump, "< "))
   231  	}
   232  
   233  	cReq := &CapturedRequest{}
   234  
   235  	body, err := io.ReadAll(resp.Body)
   236  	if err != nil {
   237  		return nil, nil, err
   238  	}
   239  
   240  	// we cannot assume the response is JSON
   241  	if resp.Header.Get("Content-type") == "application/json" {
   242  		err = json.Unmarshal(body, cReq)
   243  		if err != nil {
   244  			return nil, nil, fmt.Errorf("unexpected error reading response: %w", err)
   245  		}
   246  	} else {
   247  		cReq.Method = method // assume it made the right request if the service being called isn't echoing
   248  	}
   249  
   250  	cRes := &CapturedResponse{
   251  		StatusCode:    resp.StatusCode,
   252  		ContentLength: resp.ContentLength,
   253  		Protocol:      resp.Proto,
   254  		Headers:       resp.Header,
   255  	}
   256  
   257  	if IsRedirect(resp.StatusCode) {
   258  		redirectURL, err := resp.Location()
   259  		if err != nil {
   260  			return nil, nil, err
   261  		}
   262  		cRes.RedirectRequest = &RedirectRequest{
   263  			Scheme: redirectURL.Scheme,
   264  			Host:   redirectURL.Hostname(),
   265  			Port:   redirectURL.Port(),
   266  			Path:   redirectURL.Path,
   267  		}
   268  	}
   269  
   270  	return cReq, cRes, nil
   271  }
   272  
   273  func tlsClientConfig(server string, certPem []byte, keyPem []byte) (*tls.Config, error) {
   274  	// Create a certificate from the provided cert and key
   275  	cert, err := tls.X509KeyPair(certPem, keyPem)
   276  	if err != nil {
   277  		return nil, fmt.Errorf("unexpected error creating cert: %w", err)
   278  	}
   279  
   280  	// Add the provided cert as a trusted CA
   281  	certPool := x509.NewCertPool()
   282  	if !certPool.AppendCertsFromPEM(certPem) {
   283  		return nil, fmt.Errorf("unexpected error adding trusted CA: %w", err)
   284  	}
   285  
   286  	if server == "" {
   287  		return nil, fmt.Errorf("unexpected error, server name required for TLS")
   288  	}
   289  
   290  	// Create the tls Config for this provided host, cert, and trusted CA
   291  	// Disable G402: TLS MinVersion too low. (gosec)
   292  	// #nosec G402
   293  	return &tls.Config{
   294  		Certificates: []tls.Certificate{cert},
   295  		ServerName:   server,
   296  		RootCAs:      certPool,
   297  	}, nil
   298  }
   299  
   300  // IsRedirect returns true if a given status code is a redirect code.
   301  func IsRedirect(statusCode int) bool {
   302  	switch statusCode {
   303  	case http.StatusMultipleChoices,
   304  		http.StatusMovedPermanently,
   305  		http.StatusFound,
   306  		http.StatusSeeOther,
   307  		http.StatusNotModified,
   308  		http.StatusUseProxy,
   309  		http.StatusTemporaryRedirect,
   310  		http.StatusPermanentRedirect:
   311  		return true
   312  	}
   313  	return false
   314  }
   315  
   316  // IsTimeoutError returns true if a given status code is a timeout error code.
   317  func IsTimeoutError(statusCode int) bool {
   318  	switch statusCode {
   319  	case http.StatusRequestTimeout,
   320  		http.StatusGatewayTimeout:
   321  		return true
   322  	}
   323  	return false
   324  }
   325  
   326  var startLineRegex = regexp.MustCompile(`(?m)^`)
   327  
   328  func formatDump(data []byte, prefix string) string {
   329  	data = startLineRegex.ReplaceAllLiteral(data, []byte(prefix))
   330  	return string(data)
   331  }
   332  

View as plain text