...

Source file src/cloud.google.com/go/httpreplay/internal/proxy/replay.go

Documentation: cloud.google.com/go/httpreplay/internal/proxy

     1  // Copyright 2018 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package proxy
    16  
    17  import (
    18  	"bytes"
    19  	"encoding/json"
    20  	"errors"
    21  	"fmt"
    22  	"log"
    23  	"net/http"
    24  	"os"
    25  	"reflect"
    26  	"sync"
    27  
    28  	"github.com/google/martian/v3/martianlog"
    29  )
    30  
    31  // ForReplaying returns a Proxy configured to replay.
    32  func ForReplaying(filename string, port int) (*Proxy, error) {
    33  	p, err := newProxy(filename)
    34  	if err != nil {
    35  		return nil, err
    36  	}
    37  	lg, err := readLog(filename)
    38  	if err != nil {
    39  		return nil, err
    40  	}
    41  	calls, err := constructCalls(lg)
    42  	if err != nil {
    43  		return nil, err
    44  	}
    45  	p.Initial = lg.Initial
    46  	p.mproxy.SetRoundTripper(&replayRoundTripper{
    47  		calls:         calls,
    48  		ignoreHeaders: p.ignoreHeaders,
    49  		conv:          lg.Converter,
    50  	})
    51  
    52  	// Debug logging.
    53  	// TODO(jba): factor out from here and ForRecording.
    54  	logger := martianlog.NewLogger()
    55  	logger.SetDecode(true)
    56  	p.mproxy.SetRequestModifier(logger)
    57  	p.mproxy.SetResponseModifier(logger)
    58  
    59  	if err := p.start(port); err != nil {
    60  		return nil, err
    61  	}
    62  	return p, nil
    63  }
    64  
    65  func readLog(filename string) (*Log, error) {
    66  	bytes, err := os.ReadFile(filename)
    67  	if err != nil {
    68  		return nil, err
    69  	}
    70  	var lg Log
    71  	if err := json.Unmarshal(bytes, &lg); err != nil {
    72  		return nil, fmt.Errorf("%s: %v", filename, err)
    73  	}
    74  	if lg.Version != LogVersion {
    75  		return nil, fmt.Errorf(
    76  			"httpreplay: read log version %s but current version is %s; re-record the log",
    77  			lg.Version, LogVersion)
    78  	}
    79  	return &lg, nil
    80  }
    81  
// A call is an HTTP request and its matching response.
//
// constructCalls builds calls with a positional literal, so the field order
// (req, then res) must not change.
type call struct {
	req *Request  // the recorded request
	res *Response // the recorded response paired with req
}
    87  
    88  func constructCalls(lg *Log) ([]*call, error) {
    89  	ignoreIDs := map[string]bool{} // IDs of requests to ignore
    90  	callsByID := map[string]*call{}
    91  	var calls []*call
    92  	for _, e := range lg.Entries {
    93  		if ignoreIDs[e.ID] {
    94  			continue
    95  		}
    96  		c, ok := callsByID[e.ID]
    97  		switch {
    98  		case !ok:
    99  			if e.Request == nil {
   100  				return nil, fmt.Errorf("first entry for ID %s does not have a request", e.ID)
   101  			}
   102  			if e.Request.Method == "CONNECT" {
   103  				// Ignore CONNECT methods.
   104  				ignoreIDs[e.ID] = true
   105  			} else {
   106  				c := &call{e.Request, e.Response}
   107  				calls = append(calls, c)
   108  				callsByID[e.ID] = c
   109  			}
   110  		case e.Request != nil:
   111  			if e.Response != nil {
   112  				return nil, errors.New("entry has both request and response")
   113  			}
   114  			c.req = e.Request
   115  		case e.Response != nil:
   116  			c.res = e.Response
   117  		default:
   118  			return nil, errors.New("entry has neither request nor response")
   119  		}
   120  	}
   121  	for _, c := range calls {
   122  		if c.req == nil || c.res == nil {
   123  			return nil, fmt.Errorf("missing request or response: %+v", c)
   124  		}
   125  	}
   126  	return calls, nil
   127  }
   128  
// replayRoundTripper is the round tripper that ForReplaying installs on the
// proxy. Its RoundTrip method answers each incoming request with the response
// of a matching recorded call instead of sending the request anywhere.
type replayRoundTripper struct {
	mu            sync.Mutex      // held by RoundTrip while it searches and updates calls
	calls         []*call         // recorded calls; RoundTrip sets an entry to nil once it has been replayed
	ignoreHeaders map[string]bool // passed to requestsMatch as the header names to skip when comparing
	conv          *Converter      // converts the incoming *http.Request before it is matched
}
   135  
   136  func (r *replayRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
   137  	if req.Body != nil {
   138  		defer req.Body.Close()
   139  	}
   140  	creq, err := r.conv.convertRequest(req)
   141  	if err != nil {
   142  		return nil, err
   143  	}
   144  	r.mu.Lock()
   145  	defer r.mu.Unlock()
   146  	for i, call := range r.calls {
   147  		if call == nil {
   148  			continue
   149  		}
   150  		if requestsMatch(creq, call.req, r.ignoreHeaders) {
   151  			r.calls[i] = nil // nil out this call so we don't reuse it
   152  			return toHTTPResponse(call.res, req), nil
   153  		}
   154  	}
   155  	return nil, fmt.Errorf("no matching request for %+v", req)
   156  }
   157  
   158  // Report whether the incoming request in matches the candidate request cand.
   159  func requestsMatch(in, cand *Request, ignoreHeaders map[string]bool) bool {
   160  	if in.Method != cand.Method {
   161  		return false
   162  	}
   163  	if in.URL != cand.URL {
   164  		return false
   165  	}
   166  	if in.MediaType != cand.MediaType {
   167  		return false
   168  	}
   169  	if len(in.BodyParts) != len(cand.BodyParts) {
   170  		return false
   171  	}
   172  	for i, p1 := range in.BodyParts {
   173  		if !bytes.Equal(p1, cand.BodyParts[i]) {
   174  			return false
   175  		}
   176  	}
   177  	// Check headers last. See DebugHeaders.
   178  	return headersMatch(in.Header, cand.Header, ignoreHeaders)
   179  }
   180  
   181  // DebugHeaders helps to determine whether a header should be ignored.
   182  // When true, if requests have the same method, URL and body but differ
   183  // in a header, the first mismatched header is logged.
   184  var DebugHeaders = false
   185  
   186  func headersMatch(in, cand http.Header, ignores map[string]bool) bool {
   187  	for k1, v1 := range in {
   188  		if ignores[k1] {
   189  			continue
   190  		}
   191  		v2 := cand[k1]
   192  		if v2 == nil {
   193  			if DebugHeaders {
   194  				log.Printf("header %s: present in incoming request but not candidate", k1)
   195  			}
   196  			return false
   197  		}
   198  		if !reflect.DeepEqual(v1, v2) {
   199  			if DebugHeaders {
   200  				log.Printf("header %s: incoming %v, candidate %v", k1, v1, v2)
   201  			}
   202  			return false
   203  		}
   204  	}
   205  	for k2 := range cand {
   206  		if ignores[k2] {
   207  			continue
   208  		}
   209  		if in[k2] == nil {
   210  			if DebugHeaders {
   211  				log.Printf("header %s: not in incoming request but present in candidate", k2)
   212  			}
   213  			return false
   214  		}
   215  	}
   216  	return true
   217  }
   218  

View as plain text